Session Jordan_Normal_Form

Theory Missing_Ring

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Missing Ring›

text ‹This theory contains several lemmas which might be worth including in the Isabelle distribution.›

theory Missing_Ring
imports
  "HOL-Algebra.Ring"
begin

(* Reindexing lemmas for finprod, stated inside the comm_monoid context. *)
context comm_monoid
begin

(* Reindexing along a bijection: from bij_betw h S T the product of
   (λx. g (h x)) over S equals the product of g over T. Derived from
   finprod_reindex after unfolding bij_betw_def.
   NOTE(review): the connectives between the premises are not visible in this
   copy of the text; confirm the exact statement against the upstream theory. *)
lemma finprod_reindex_bij_betw: "bij_betw h S T 
   g  h ` S  carrier G 
   finprod G (λx. g (h x)) S = finprod G g T"
  using finprod_reindex[of g h S] unfolding bij_betw_def by auto

(* Witness-based reindexing. The facts in "witness" relate the maps i and j
   between S and T (i (j a) = a, j (i b) = b, plus range conditions), and "eq"
   gives h (j a) = g a. Conclusion: finprod G g S = finprod G h T. *)
lemma finprod_reindex_bij_witness:
  assumes witness:
    "a. a  S  i (j a) = a"
    "a. a  S  j a  T"
    "b. b  T  j (i b) = b"
    "b. b  T  i b  S"
  assumes eq:
    "a. a  S  h (j a) = g a"
  assumes g: "g  S  carrier G"
  and h: "h  j ` S  carrier G"
  shows "finprod G g S = finprod G h T"
proof -
  (* j is a bijection between S and T, obtained from the witness facts *)
  have b: "bij_betw j S T"
    using bij_betw_byWitness[where A=S and f=j and f'=i and A'=T] witness by auto
  (* rewrite the left-hand side pointwise using eq *)
  have fp: "finprod G g S = finprod G (λx. h (j x)) S"
    by (rule finprod_cong, insert eq g, auto)
  show ?thesis
    using finprod_reindex_bij_betw[OF b h] unfolding fp .
qed
end

(* Names add.finprod_reindex_bij_witness as finsum_reindex_bij_witness inside abelian_monoid. *)
lemmas (in abelian_monoid) finsum_reindex_bij_witness = add.finprod_reindex_bij_witness

(* csemiring: the semiring locale combined with comm_monoid R. *)
locale csemiring = semiring + comm_monoid R

(* Registers csemiring as a sublocale of cring; the obligations are closed by the default proof "..". *)
context cring
begin
sublocale csemiring ..
end

(* If f a = 𝟭 under the stated condition on a and A, then finprod G f A = 𝟭.
   Proved by induction with the rule infinite_finite_induct. *)
lemma (in comm_monoid) finprod_one': 
  "( a. a  A  f a = 𝟭)  finprod G f A = 𝟭"
  by (induct A rule: infinite_finite_induct, auto)

(* Splits one element a off a finite product: relates finprod G f A to f a and
   finprod G f (A - {a}). The proof instantiates finprod_Un_disjoint with the
   sets {a} and A - {a}.
   NOTE(review): the operator between "f a" and the remaining product is not
   visible in this copy; presumably the monoid multiplication, confirm. *)
lemma (in comm_monoid) finprod_split: 
  "finite A  f ` A  carrier G  a  A  finprod G f A = f a  finprod G f (A - {a})"
  by (rule trans[OF trans[OF _ finprod_Un_disjoint[of "{a}" "A - {a}" f]]], auto,
  rule arg_cong[of _ _ "finprod G f"], auto)

(* A nested product (outer over A, inner over B) equals a single product of
   (λ (a,b). g a b) over A × B, for finite A and B.
   Proof: induction on A. In the insert case both sides are split into the
   part for the new element a' and the part for A (facts idl and idr); the
   part for {a'} × B is handled by a nested induction on B. *)
lemma (in comm_monoid) finprod_finprod:
  "finite A  finite B  ( a b. a  A   b  B  g a b  carrier G) 
  finprod G (λ a. finprod G (g a) B) A = finprod G (λ (a,b). g a b) (A × B)"
proof (induct A rule: finite_induct)
  case (insert a' A)
  note IH = this
  (* ?l / ?r: left- and right-hand side of the goal for insert a' A *)
  let ?l = "(ainsert a' A. finprod G (g a) B)"
  let ?r = "(ainsert a' A × B. case a of (a, b)  g a b)"
  have "?l = finprod G (g a') B  (aA. finprod G (g a) B)"
    using IH by simp
  also have "(aA. finprod G (g a) B) = finprod G (λ (a,b). g a b) (A × B)"
    by (rule IH(3), insert IH, auto)
  finally have idl: "?l = finprod G (g a') B  finprod G (λ (a,b). g a b) (A × B)" .
  (* decompose the index set of the right-hand side *)
  from IH(2) have "insert a' A × B = {a'} × B  A × B" by auto
  hence "?r = (a{a'} × B  A × B. case a of (a, b)  g a b)" by simp
  also have " = (a{a'} × B. case a of (a, b)  g a b)  (a A × B. case a of (a, b)  g a b)"
    by (rule finprod_Un_disjoint, insert IH, auto)
  (* the {a'} × B part collapses to the inner product for a'; nested induction on B *)
  also have "(a{a'} × B. case a of (a, b)  g a b) = finprod G (g a') B"
    using IH(4) IH(5)
  proof (induct B rule: finite_induct)
    case (insert b' B)
    note IH = this
    have id: "(a{a'} × B. case a of (a, b)  g a b) = finprod G (g a') B"
      by (rule IH(3)[OF IH(4)], auto)
    have id2: " x F. {a'} × insert x F = insert (a',x) ({a'} × F)" by auto
    have id3: "(ainsert (a', b') ({a'} × B). case a of (a, b)  g a b)
      = g a' b'  (a({a'} × B). case a of (a, b)  g a b)"
      by (rule trans[OF finprod_insert], insert IH, auto)
    show ?case unfolding id2 id3 id
      by (rule sym, rule finprod_insert, insert IH, auto)
  qed simp
  finally have idr: "?r = finprod G (g a') B  (aA × B. case a of (a, b)  g a b)" .
  show ?case unfolding idl idr ..
qed simp

(* Swapping the pair components: the product of (λ (b,a). g a b) over B × A
   equals the product of (λ (a,b). g a b) over A × B.
   Proof: finprod_reindex with the swap map (λ (a,b). (b,a)); the two local
   simp facts compute its image and simplify the composed integrand. *)
lemma (in comm_monoid) finprod_swap:
  assumes "finite A" "finite B" " a b. a  A   b  B  g a b  carrier G"
  shows "finprod G (λ (b,a). g a b) (B × A) = finprod G (λ (a,b). g a b) (A × B)"
proof -
  have [simp]: "(λ(a, b). (b, a)) ` (A × B) = B × A" by auto
  have [simp]: "(λ x. case case x of (a, b)  (b, a) of (a, b)  g b a) = (λ (a,b). g a b)"
    by (intro ext, auto)
  show ?thesis 
    by (rule trans[OF trans[OF _ finprod_reindex[of "λ (a,b). g b a" "λ (a,b). (b,a)"]]],
    insert assms, auto simp: inj_on_def)
qed

(* Exchanging the order of two nested finite products. Obtained by simp from
   finprod_finprod (instantiated for A B and for B A) and finprod_swap. *)
lemma (in comm_monoid) finprod_finprod_swap:
  "finite A  finite B  ( a b. a  A   b  B  g a b  carrier G) 
  finprod G (λ a. finprod G (g a) B) A = finprod G (λ b. finprod G (λ a. g a b) A) B"
  using finprod_finprod[of A B] finprod_finprod[of B A] finprod_swap[of A B]
  by simp



(* finsum names, inside semiring, for the add.-qualified versions of
   finprod_one', finprod_split and finprod_finprod_swap. *)
lemmas (in semiring) finsum_zero' = add.finprod_one' 
lemmas (in semiring) finsum_split = add.finprod_split 
lemmas (in semiring) finsum_finsum_swap = add.finprod_finprod_swap


(* In a csemiring: for finite A, if f a = 𝟬 holds for an index a of A
   (NOTE(review): the quantifier symbol is missing in this copy; presumably
   existential, confirm), then finprod R f A = 𝟬.
   Proof: induction on A, with a case split on whether the inserted element
   is mapped to 𝟬. *)
lemma (in csemiring) finprod_zero: 
  "finite A  f  A  carrier R  aA. f a = 𝟬
    finprod R f A = 𝟬"
proof (induct A rule: finite_induct)
  case (insert a A)
  (* unfold the product over insert a A *)
  from finprod_insert[OF insert(1-2), of f] insert(4)
  have ins: "finprod R f (insert a A) = f a  finprod R f A" by simp
  have fA: "finprod R f A  carrier R"
    by (rule finprod_closed, insert insert, auto)
  show ?case
  proof (cases "f a = 𝟬")
    case True
    with fA show ?thesis unfolding ins by simp
  next
    case False
    (* otherwise the zero value occurs within A and the induction hypothesis applies *)
    with insert(5) have " a  A. f a = 𝟬" by auto
    from insert(3)[OF _ this] insert have "finprod R f A = 𝟬" by auto
    with insert show ?thesis unfolding ins by auto
  qed
qed simp

(* Relates the combination of finsum R f A and finsum R g B to a double sum
   over i in A and j in B of f i and g j (NOTE(review): the operators are not
   visible in this copy; presumably ring multiplication, confirm).
   Proof: rewrite with finsum_ldistr, then apply finsum_rdistr to each summand
   under finsum_cong'. *)
lemma (in semiring) finsum_product:
  assumes A: "finite A" and B: "finite B"
  and f: "f  A  carrier R" and g: "g  B  carrier R" 
  shows "finsum R f A  finsum R g B = (iA. jB. f i  g j)"
  unfolding finsum_ldistr[OF A finsum_closed[OF g] f]
proof (rule finsum_cong'[OF refl])
  fix a
  assume a: "a  A"
  show "f a  finsum R g B = (jB. f a  g j)"
  by (rule finsum_rdistr[OF B _ g], insert a f, auto)
qed (insert f g B, auto intro: finsum_closed)
    
(* Two introduction rules for a being in Units R: given a in the carrier and a
   unit p whose combination with a (on either side) equals 𝟭. Both statements
   are proved by the same metis call. *)
lemma (in semiring) Units_one_side_I: 
  "a  carrier R  p  Units R  p  a = 𝟭  a  Units R"
  "a  carrier R  p  Units R  a  p = 𝟭  a  Units R"
  by (metis Units_closed Units_inv_Units Units_l_inv inv_unique)+

(* Declares ordered_cancel_ab_semigroup_add as a subclass within
   ordered_cancel_semiring; closed by the default proof "..". *)
context ordered_cancel_semiring begin
subclass ordered_cancel_ab_semigroup_add ..
end

text ‹partially ordered variant›
(* Class built from semiring, comm_monoid_add and ordered_cancel_ab_semigroup_add
   with two axioms: multiplication by c with 0 < c is strictly monotone, on the
   left (c * a < c * b) and on the right (a * c < b * c). *)
class ordered_semiring_strict = semiring + comm_monoid_add + ordered_cancel_ab_semigroup_add +
  assumes mult_strict_left_mono: "a < b  0 < c  c * a < c * b"
  assumes mult_strict_right_mono: "a < b  0 < c  a * c < b * c"
begin

subclass semiring_0_cancel ..

(* The non-strict monotonicity laws follow by unfolding le_less and
   distinguishing the case c = 0. *)
subclass ordered_semiring
proof
  fix a b c :: 'a
  assume A: "a  b" "0  c"
  from A show "c * a  c * b"
    unfolding le_less
    using mult_strict_left_mono by (cases "c = 0") auto
  from A show "a * c  b * c"
    unfolding le_less
    using mult_strict_right_mono by (cases "c = 0") auto
qed

(* Sign rules for products, each an instance of one strict monotonicity axiom. *)
lemma mult_pos_pos[simp]: "0 < a  0 < b  0 < a * b"
using mult_strict_left_mono [of 0 b a] by simp

lemma mult_pos_neg: "0 < a  b < 0  a * b < 0"
using mult_strict_left_mono [of b 0 a] by simp

lemma mult_neg_pos: "a < 0  0 < b  a * b < 0"
using mult_strict_right_mono [of a 0 b] by simp

text ‹Legacy - use mult_neg_pos›
lemma mult_pos_neg2: "0 < a  b < 0  b * a < 0" 
by (drule mult_strict_right_mono [of b 0], auto)

text‹Strict monotonicity in both arguments›
lemma mult_strict_mono:
  assumes "a < b" and "c < d" and "0 < b" and "0  c"
  shows "a * c < b * d"
  using assms apply (cases "c=0")
  apply (simp)
  apply (erule mult_strict_right_mono [THEN less_trans])
  apply (force simp add: le_less)
  apply (erule mult_strict_left_mono, assumption)
  done

text‹This weaker variant has more natural premises›
lemma mult_strict_mono':
  assumes "a < b" and "c < d" and "0  a" and "0  c"
  shows "a * c < b * d"
by (rule mult_strict_mono) (insert assms, auto)

(* Mixed variant: conclusion a * c < b * d, via the intermediate a * c < b * c. *)
lemma mult_less_le_imp_less:
  assumes "a < b" and "c  d" and "0  a" and "0 < c"
  shows "a * c < b * d"
  using assms apply (subgoal_tac "a * c < b * c")
  apply (erule less_le_trans)
  apply (erule mult_left_mono)
  apply simp
  apply (erule mult_strict_right_mono)
  apply assumption
  done

(* Mixed variant with the roles of the two comparisons exchanged. *)
lemma mult_le_less_imp_less:
  assumes "a  b" and "c < d" and "0 < a" and "0  c"
  shows "a * c < b * d"
  using assms apply (subgoal_tac "a * c  b * c")
  apply (erule le_less_trans)
  apply (erule mult_strict_left_mono)
  apply simp
  apply (erule mult_right_mono)
  apply simp
  done

end

(* ordered_idom: idom combined with ordered_semiring_strict and the extra axiom
   0 < 1. The body establishes several subclass relations and shows that
   of_nat is injective, which yields the subclass ring_char_0. *)
class ordered_idom = idom + ordered_semiring_strict +
  assumes zero_less_one [simp]: "0 < 1" begin

subclass semiring_1 ..
subclass comm_ring_1 ..
subclass ordered_ring ..
subclass ordered_comm_semiring by(unfold_locales, fact mult_left_mono)
subclass ordered_ab_semigroup_add ..

(* Order relation between of_nat x and 0, by induction on x.
   NOTE(review): the comparison symbol is missing in this copy of the statement. *)
lemma of_nat_ge_0[simp]: "of_nat x  0"
proof (induct x)
  case 0 thus ?case by auto
  next case (Suc x)
    hence "0  of_nat x" by auto
    also have "of_nat x < of_nat (Suc x)" by auto
    finally show ?case by auto
qed

(* Relates "of_nat x = 0" and "x = 0"; the successor case uses of_nat (Suc x) > 0. *)
lemma of_nat_eq_0[simp]: "of_nat x = 0  x = 0"
proof(induct x,simp)
  case (Suc x)
    have "of_nat (Suc x) > 0" apply(rule le_less_trans[of _ "of_nat x"]) by auto
    thus ?case by auto
qed

(* Injectivity of of_nat, by induction on x (arbitrary y) with an inner
   induction on y; the mixed 0 / Suc cases are refuted through of_nat_eq_0. *)
lemma inj_of_nat: "inj (of_nat :: nat  'a)"
proof(rule injI)
  fix x y show "of_nat x = of_nat y  x = y"
  proof (induct x arbitrary: y)
    case 0 thus ?case
      proof (induct y)
        case 0 thus ?case by auto
        next case (Suc y)
          hence "of_nat (Suc y) = 0" by auto
          hence "Suc y = 0" unfolding of_nat_eq_0 by auto
          hence False by auto
          thus ?case by auto
      qed
    next case (Suc x)
      thus ?case
      proof (induct y)
        case 0
          hence "of_nat (Suc x) = 0" by auto
          hence "Suc x = 0" unfolding of_nat_eq_0 by auto
          hence False by auto
          thus ?case by auto
        next case (Suc y) thus ?case by auto
      qed
  qed
qed

subclass ring_char_0 by(unfold_locales, fact inj_of_nat)

end

(*
instance linordered_idom ⊆ ordered_semiring_strict by (intro_classes,auto)
instance linordered_idom ⊆ ordered_idom by (intro_classes, auto)
*)

end

Theory Missing_Permutations

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Missing Permutations›

text ‹This theory provides some definitions and lemmas on permutations which we did not find in the 
  Isabelle distribution.›

theory Missing_Permutations
imports
  Missing_Ring
  "HOL-Combinatorics.Permutations"
begin

(* signof p: 1 if sign p = 1, and - 1 otherwise; the result type is of sort ring_1. *)
definition signof :: "(nat  nat)  'a :: ring_1" where
  "signof p = (if sign p = 1 then 1 else - 1)"

(* signof of the identity is 1, stated both for id and for (λ x. x). *)
lemma signof_id[simp]: "signof id = 1" "signof (λ x. x) = 1"
  unfolding signof_def sign_id id_def[symmetric] by auto

(* For a permutation p of a finite set S, the inverse has the same signof as p. *)
lemma signof_inv: "finite S  p permutes S  signof (Hilbert_Choice.inv p) = signof p"
  unfolding signof_def using sign_inverse permutation_permutes by metis

(* Relates signof p to the two-element set {1, - 1}; immediate from signof_def.
   NOTE(review): the relation symbol is missing in this copy; presumably membership. *)
lemma signof_pm_one: "signof p  {1, - 1}"
  unfolding signof_def by auto

(* signof is multiplicative on p o q, for p permuting {0..<n} and q permuting
   {0..<m}. Proof: both are permutations (permutation_permutes), then
   sign_compose and a case analysis after unfolding sign_def. *)
lemma signof_compose: assumes "p permutes {0..<(n :: nat)}"
  and "q permutes {0 ..<(m :: nat)}"
  shows "signof (p o q) = signof p * signof q"
proof -
  from assms have pp: "permutation p" "permutation q"
    by (auto simp: permutation_permutes)
  show "signof (p o q) = signof p * signof q"
    unfolding signof_def sign_compose[OF pp] 
    by (auto simp: sign_def split: if_splits)
qed

(* For p permuting A, an expression built from the image p ` A and B equals the
   same expression built from A and B; proved by simp with permutes_image.
   NOTE(review): the operator between the sets is missing in this copy;
   the lemma name suggests a function-space arrow, confirm. *)
lemma permutes_funcset: "p permutes A  (p ` A  B) = (A  B)"
  by (simp add: permutes_image)

(* Further finprod lemmas in comm_monoid: invariance under a permutation of the
   index set, and the product over a singleton set. *)
context comm_monoid
begin
(* Relates finprod G f S to the product of f combined with p over S, for p
   permuting S (NOTE(review): the operator between f and p is missing in this
   copy; the final "unfolding o_def" suggests function composition).
   Proof: p is injective on S, so finprod_reindex applies, and the image of S
   under p is rewritten with permutes_image. *)
lemma finprod_permute:
  assumes p: "p permutes S"
  and f: "f  S  carrier G"
  shows "finprod G f S = finprod G (f  p) S"
proof -
  from p permutes S have "inj p"
    by (rule permutes_inj)
  then have "inj_on p S"
    by (auto intro: subset_inj_on)
  from finprod_reindex[OF _ this, unfolded permutes_image[OF p], OF f]
  show ?thesis unfolding o_def .
qed

(* The product over the singleton {a} is f a, via finprod_insert on {a}. *)
lemma finprod_singleton_set[simp]: assumes "f a  carrier G"
  shows "finprod G f {a} = f a"
proof -
  have "finprod G f {a} = f a  finprod G f {}"
    by (rule finprod_insert, insert assms, auto)
  also have " = f a" using assms by auto
  finally show ?thesis .
qed
end

(* finsum names, inside semiring, for the add.-qualified versions of
   finprod_permute and finprod_singleton_set. *)
lemmas (in semiring) finsum_permute = add.finprod_permute
lemmas (in semiring) finsum_singleton_set = add.finprod_singleton_set

(* Simp facts for p permuting {0..<n}: for i < n both p i and inv p i are
   below n, and p and Hilbert_Choice.inv p cancel each other in either order.
   The first two goals are shown explicitly under the assumption i < n; the
   two cancellation facts are closed by the final qed using permutes_inverses. *)
lemma permutes_less[simp]: assumes p: "p permutes {0..<(n :: nat)}"
  shows "i < n  p i < n" "i < n  Hilbert_Choice.inv p i < n" 
  "p (Hilbert_Choice.inv p i) = i"
  "Hilbert_Choice.inv p (p i) = i"
proof -
  assume i: "i < n"
  show "p i < n" using permutes_in_image[OF p] i by auto
  let ?inv = "Hilbert_Choice.inv p" 
  have "n. ?inv (p n) = n"
      using permutes_inverses[OF p] by simp
  thus "?inv i < n" 
      by (metis (no_types) atLeastLessThan_iff f_inv_into_f inv_into_into le0 permutes_image[OF p] i)
qed (insert permutes_inverses[OF p], auto)
    
(* Two reindexing lemmas, in cring, for sums ranging over the set of all
   permutations of S. *)
context cring
begin

(* A sum over {p. p permutes S} is unchanged when f is applied to
   Hilbert_Choice.inv p instead of p.
   Proof: inv is injective on the permutations of S (th0) and maps that set
   onto itself (th1), so finsum_reindex applies; th2 puts the right-hand side
   into the shape finsum_reindex produces. *)
lemma finsum_permutations_inverse: 
  assumes f: "f  {p. p permutes S}  carrier R"
  shows "finsum R f {p. p permutes S} = finsum R (λp. f(Hilbert_Choice.inv p)) {p. p permutes S}"
  (is "?lhs = ?rhs")
proof -
  let ?inv = "Hilbert_Choice.inv"
  let ?S = "{p . p permutes S}"
  have th0: "inj_on ?inv ?S"
  proof (auto simp add: inj_on_def)
    fix q r
    assume q: "q permutes S"
      and r: "r permutes S"
      and qr: "?inv q = ?inv r"
    then have "?inv (?inv q) = ?inv (?inv r)"
      by simp
    with permutes_inv_inv[OF q] permutes_inv_inv[OF r] show "q = r"
      by metis
  qed
  have th1: "?inv ` ?S = ?S"
    using image_inverse_permutations by blast
  have th2: "?rhs = finsum R (f  ?inv) ?S"
    by (simp add: o_def)
  from finsum_reindex[OF _ th0, of f] show ?thesis unfolding th1 th2 using f .
qed

(* A sum over {p. p permutes S} is unchanged when every p is combined on the
   right with a fixed q that permutes S (NOTE(review): the operator between p
   and q is missing in this copy; presumably function composition).
   Proof: the map (λp. p  q) is injective on that set (th1) and maps it onto
   itself (th3), so finsum_reindex applies. *)
lemma finsum_permutations_compose_right: assumes q: "q permutes S"
  and *: "f  {p. p permutes S}  carrier R"
  shows "finsum R f {p. p permutes S} = finsum R (λp. f(p  q)) {p. p permutes S}"
  (is "?lhs = ?rhs")
proof -
  let ?S = "{p. p permutes S}"
  let ?inv = "Hilbert_Choice.inv"
  have th0: "?rhs = finsum R (f  (λp. p  q)) ?S"
    by (simp add: o_def)
  have th1: "inj_on (λp. p  q) ?S"
  proof (auto simp add: inj_on_def)
    fix p r
    assume "p permutes S"
      and r: "r permutes S"
      and rp: "p  q = r  q"
    then have "p  (q  ?inv q) = r  (q  ?inv q)"
      by (simp add: o_assoc)
    with permutes_surj[OF q, unfolded surj_iff] show "p = r"
      by simp
  qed
  have th3: "(λp. p  q) ` ?S = ?S"
    using image_compose_permutations_right[OF q] by auto
  from finsum_reindex[OF _ th1, of f]
  show ?thesis unfolding th0 th1 th3 using * .
qed

end

text ‹The following lemma is slightly generalized from Determinants.thy in HMA.›

(* For finite S and finite T, a set of functions f is finite, where f is
   constrained by a condition relating f i to S for i and T, and by f i = i
   in the complementary case (NOTE(review): the membership / non-membership
   symbols are missing in this copy; confirm which condition is which).
   Proof: induction on T. In the insert case the set is written as the image
   of a product set S × {...} under ?f, which overrides a function at the new
   index a; finiteness follows from finite_cartesian_product and finite_imageI. *)
lemma finite_bounded_functions:
  assumes fS: "finite S"
  shows "finite T  finite {f. (i  T. f i  S)  (i. i  T  f i = i)}"
proof (induct T rule: finite_induct)
  case empty
  have th: "{f. i. f i = i} = {id}"
    by auto
  show ?case
    by (auto simp add: th)
next
  case (insert a T)
  let ?f = "λ(y,g) i. if i = a then y else g i"
  let ?S = "?f ` (S × {f. (iT. f i  S)  (i. i  T  f i = i)})"
  have "?S = {f. (i insert a T. f i  S)  (i. i  insert a T  f i = i)}"
    apply (auto simp add: image_iff)
    apply (rule_tac x="x a" in bexI)
    apply (rule_tac x = "λi. if i = a then i else x i" in exI)
    apply (insert insert, auto)
    done
  with finite_imageI[OF finite_cartesian_product[OF fS insert.hyps(3)], of ?f]
  show ?case
    by metis
qed

(* Variant of the finiteness statement for bounded functions in which the
   second condition reads f i = j for a fixed j (instead of f i = i). The
   proof follows the same scheme: induction on T, and in the insert case the
   set is the image of S × {...} under ?f, which overrides a function at a.
   NOTE(review): membership symbols are missing in this copy of the statement. *)
lemma finite_bounded_functions':
  assumes fS: "finite S"
  shows "finite T  finite {f. (i  T. f i  S)  (i. i  T  f i = j)}"
proof (induct T rule: finite_induct)
  case empty
  have th: "{f. i. f i = j} = {(λ x. j)}"
    by auto
  show ?case
    by (auto simp add: th)
next
  case (insert a T)
  let ?f = "λ(y,g) i. if i = a then y else g i"
  let ?S = "?f ` (S × {f. (iT. f i  S)  (i. i  T  f i = j)})"
  have "?S = {f. (i insert a T. f i  S)  (i. i  insert a T  f i = j)}"
    apply (auto simp add: image_iff)
    apply (rule_tac x="x a" in bexI)
    apply (rule_tac x = "λi. if i = a then j else x i" in exI)
    apply (insert insert, auto)
    done
  with finite_imageI[OF finite_cartesian_product[OF fS insert.hyps(3)], of ?f]
  show ?case
    by metis
qed

(* Anonymous context fixing two sets A and B and maps a_to_b, b_to_a between
   them. Assumptions ab / ba concern where the maps send elements of A / B;
   ab_ba / ba_ab state that the two maps cancel each other. *)
context
  fixes A :: "'a set" 
    and B :: "'b set"
    and a_to_b :: "'a  'b"
    and b_to_a :: "'b  'a"
  assumes ab: " a. a  A  a_to_b a  B"
    and ba: " b. b  B  b_to_a b  A"
    and ab_ba: " a. a  A  b_to_a (a_to_b a) = a"
    and ba_ab: " b. b  B  a_to_b (b_to_a b) = b"
begin

(* Bundle of facts about an element a and a permutation p of B (with ip the
   inverse of p): statements about a, a_to_b a, their images under p and ip,
   and the results mapped back with b_to_a. Used as "memb" below. *)
qualified lemma permutes_memb: fixes p :: "'b  'b"
  assumes p: "p permutes B"
  and a: "a  A"
  defines "ip  Hilbert_Choice.inv p"
  shows "a  A" "a_to_b a  B" "ip (a_to_b a)  B" "p (a_to_b a)  B" 
    "b_to_a (p (a_to_b a))  A" "b_to_a (ip (a_to_b a))  A"
proof -
  let ?b = "a_to_b a"
  from p have ip: "ip permutes B" unfolding ip_def by (rule permutes_inv)
  note in_ip = permutes_in_image[OF ip]
  note in_p = permutes_in_image[OF p]
  show a: "a  A" by fact
  show b: "?b  B" by (rule ab[OF a])
  show pb: "p ?b  B" unfolding in_p by (rule b)
  show ipb: "ip ?b  B" unfolding in_ip by (rule b)
  show "b_to_a (p ?b)  A" by (rule ba[OF pb])
  show "b_to_a (ip ?b)  A" by (rule ba[OF ipb])
qed

(* Relates the permutations of A to the image of the permutations of B under
   ?f, where ?f q applies q between a_to_b and b_to_a on A and acts as the
   identity otherwise. The proof starts from p in ?f ` ?B and shows p in ?A,
   i.e. it establishes that every such image permutes A.
   NOTE(review): the relation symbol in the statement is missing in this copy;
   the proof direction suggests a superset inclusion, confirm. *)
lemma permutes_bij_main: 
  "{p . p permutes A}  (λ p a. if a  A then b_to_a (p (a_to_b a)) else a) ` {p . p permutes B}" 
  (is "?A  ?f ` ?B")
proof 
  note d = permutes_def
  let ?g = "λ q b. if b  B then a_to_b (q (b_to_a b)) else b"
  let ?inv = "Hilbert_Choice.inv"
  fix p
  assume p: "p  ?f ` ?B"
  then obtain q where q: "q permutes B" and p: "p = ?f q" by auto    
  let ?iq = "?inv q"
  from q have iq: "?iq permutes B" by (rule permutes_inv)
  note in_iq = permutes_in_image[OF iq]
  note in_q = permutes_in_image[OF q]
  have qiB: " b. b  B  q (?iq b) = b" using q by (rule permutes_inverses)
  have iqB: " b. b  B  ?iq (q b) = b" using q by (rule permutes_inverses)
  from q[unfolded d] 
  have q1: " b. b  B  q b = b" 
   and q2: " b. ∃!b'. q b' = b" by auto
  note memb = permutes_memb[OF q]
  show "p  ?A" unfolding p d
  proof (rule, intro conjI impI allI, force)
    (* remaining goal: every a has exactly one preimage under ?f q *)
    fix a
    show "∃!a'. ?f q a' = a"
    proof (cases "a  A")
      case True
      note a = memb[OF True]
      (* candidate preimage, obtained through the inverse of q *)
      let ?a = "b_to_a (?iq (a_to_b a))"
      show ?thesis
      proof 
        show "?f q ?a = a" using a by (simp add: ba_ab qiB ab_ba)
      next
        fix a'
        assume id: "?f q a' = a"
        show "a' = ?a"
        proof (cases "a'  A")
          case False
          thus ?thesis using id a by auto
        next
          case True
          note a' = memb[OF this]
          from id True have "b_to_a (q (a_to_b a')) = a" by simp
          from arg_cong[OF this, of "a_to_b"] a' a
          have "q (a_to_b a') = a_to_b a" by (simp add: ba_ab)
          from arg_cong[OF this, of ?iq]
          have "a_to_b a' = ?iq (a_to_b a)" unfolding iqB[OF a'(2)] .
          from arg_cong[OF this, of b_to_a] show ?thesis unfolding ab_ba[OF True] .
        qed
      qed
    next
      case False note a = this
      show ?thesis
      proof
        show "?f q a = a" using a by simp
      next
        fix a'
        assume id: "?f q a' = a"
        show "a' = a"
        proof (cases "a'  A")
          case False
          with id show ?thesis by simp
        next
          case True
          note a' = memb[OF True]
          with id False show ?thesis by auto
        qed
      qed
    qed
  qed
qed
end

(* Equality version: under the four assumptions on a_to_b and b_to_a, the set
   of permutations of A equals the image of the permutations of B under ?f.
   Proof: one inclusion is permutes_bij_main as is; for the other, ?g is the
   analogous transfer in the opposite direction, ?f (?g p) = p holds for every
   permutation p of A, and permutes_bij_main is applied with the roles of the
   two maps exchanged. *)
lemma  permutes_bij': assumes ab: " a. a  A  a_to_b a  B"
    and ba: " b. b  B  b_to_a b  A"
    and ab_ba: " a. a  A  b_to_a (a_to_b a) = a"
    and ba_ab: " b. b  B  a_to_b (b_to_a b) = b"
  shows "{p . p permutes A} = (λ p a. if a  A then b_to_a (p (a_to_b a)) else a) ` {p . p permutes B}" 
  (is "?A = ?f ` ?B")
proof -
  note one_dir = ab ba ab_ba ba_ab
  note other_dir = ba ab ba_ab ab_ba
  let ?g = "(λ p b. if b  B then a_to_b (p (b_to_a b)) else b)"
  (* local abbreviations PA, f, g introduced via "define" for the image step below *)
  define PA where "PA = ?A"
  define f where "f = ?f"
  define g where "g = ?g"
  {
    fix p
    assume "p  PA"
    hence p: "p permutes A" unfolding PA_def by simp
    from p[unfolded permutes_def] have pnA: " a. a  A  p a = a" by auto
    have "?f (?g p) = p"
    proof (rule ext)
      fix a
      show "?f (?g p) a = p a"
      proof (cases "a  A")
        case False
        thus ?thesis by (simp add: pnA)
      next
        case True note a = this
        hence "p a  A" unfolding permutes_in_image[OF p] .
        thus ?thesis using a by (simp add: ab_ba ba_ab ab)
      qed
    qed
    hence "f (g p) = p" unfolding f_def g_def .
  }
  hence "f ` g ` PA = PA" by force
  hence id: "?f ` ?g ` ?A = ?A" unfolding PA_def f_def g_def .
  have "?f ` ?B  ?A" by (rule permutes_bij_main[OF one_dir])
  moreover have "?g ` ?A  ?B" by (rule permutes_bij_main[OF ba ab ba_ab ab_ba])
  hence "?f ` ?g ` ?A  ?f ` ?B" by auto
  hence "?A  ?f ` ?B" unfolding id .
  ultimately show ?thesis by blast
qed    

(* Criterion for "f permutes S" on a finite set S of naturals: f is injective
   on S (i), satisfies the function-space condition fS, and fulfils f i = i in
   the case covered by assumption f.
   Proof: unfold permutes_def; endo_inj_surj gives f ` S = S, and unique
   existence of preimages is shown by a case split on y and S. *)
lemma inj_on_nat_permutes: assumes i: "inj_on f (S :: nat set)"
  and fS: "f  S  S"
  and fin: "finite S"
  and f: " i. i  S  f i = i"
  shows "f permutes S"
  unfolding permutes_def
proof (intro conjI allI impI, rule f)
  fix y
  from endo_inj_surj[OF fin _ i] fS have fs: "f ` S = S" by auto
  show "∃!x. f x = y"
  proof (cases "y  S")
    case False
    thus ?thesis by (intro ex1I[of _ y], insert fS f, auto)
  next
    case True
    with fs obtain x where x: "x  S" and fx: "f x = y" by force
    show ?thesis
    proof (rule ex1I, rule fx)
      fix x'
      assume fx': "f x' = y"
      with True f[of x'] have "x'  S" by metis
      from inj_onD[OF i fx[folded fx'] x this]
      show "x' = x" by simp
    qed
  qed
qed


(* For p permuting S, the set of pairs (p s, s) with s ranging over S equals
   the set of pairs (s, inv p s) with s ranging over S. Both inclusions are
   shown by rewriting a pair with the cancellation laws of p and inv p. *)
lemma permutes_pair_eq:
  assumes p: "p permutes S"
  shows "{ (p s, s) | s. s  S } = { (s, Hilbert_Choice.inv p s) | s. s  S }"
    (is "?L = ?R")
proof
  show "?L  ?R"
  proof
    fix x assume "x  ?L"
    then obtain s where x: "x = (p s, s)" and s: "s  S" by auto
    note x
    also have "(p s, s) = (p s, Hilbert_Choice.inv p (p s))"
      using permutes_inj[OF p] inv_f_f by auto
    also have "...  ?R" using s permutes_in_image[OF p] by auto
    finally show "x  ?R".
  qed
  show "?R  ?L"
  proof
    fix x assume "x  ?R"
    then obtain s
      where x: "x = (s, Hilbert_Choice.inv p s)" (is "_ = (s, ?ips)")
        and s: "s  S" by auto
    note x
    also have "(s, ?ips) = (p ?ips, ?ips)"
      using inv_f_f[OF permutes_inj[OF permutes_inv[OF p]]]
      using inv_inv_eq[OF permutes_bij[OF p]] by auto
    also have "...  ?L"
      using s permutes_in_image[OF permutes_inv[OF p]] by auto
    finally show "x  ?L".
  qed
qed

(* For f injective on A, finiteness of the image f ` A coincides with
   finiteness of A. The direction from the image to A goes through card:
   card_image transfers a positive cardinality to A, and card.infinite rules
   out an infinite A. The converse direction is closed by "qed auto". *)
lemma inj_on_finite[simp]:
  assumes inj: "inj_on f A" shows "finite (f ` A) = finite A"
proof
  assume fin: "finite (f ` A)"
  show "finite A"
  proof (cases "card (f ` A) = 0")
    case True thus ?thesis using fin by auto
    next case False 
      hence "card A > 0" unfolding card_image[OF inj] by auto
      thus ?thesis using card.infinite by force
  qed
qed auto

(* For p permuting S: the big operator over s in S of f (p s) s equals the one
   of f s (inv p s). The proof works with "prod", so this is the product
   version (the binder symbol is missing in this copy of the statement).
   Proof: both sides are reindexed by the injective pairing maps ?ps and ?ips
   (prod.reindex), and the two image sets agree by permutes_pair_eq. *)
lemma permutes_prod:
  assumes p: "p permutes S"
  shows "(sS. f (p s) s) = (sS. f s (Hilbert_Choice.inv p s))"
    (is "?l = ?r")
proof -
  let ?f = "λ(x,y). f x y"
  let ?ps = "λs. (p s, s)"
  let ?ips = "λs. (s, Hilbert_Choice.inv p s)"
  have inj1: "inj_on ?ps S" by (rule inj_onI;auto)
  have inj2: "inj_on ?ips S" by (rule inj_onI;auto)
  have "?l = prod ?f (?ps ` S)"
    using prod.reindex[OF inj1, of ?f] by simp
  also have "?ps ` S = {(p s, s) |s. s  S}" by auto
  also have "... = {(s, Hilbert_Choice.inv p s) | s. s  S}"
    unfolding permutes_pair_eq[OF p] by simp
  also have "... = ?ips ` S" by auto
  also have "prod ?f ... = ?r"
    using prod.reindex[OF inj2, of ?f] by simp
  finally show ?thesis.
qed

(* Counterpart of the product statement for "sum": for p permuting S, the big
   operator over s in S of f (p s) s equals the one of f s (inv p s) (the
   binder symbol is missing in this copy; the proof uses sum and sum.reindex).
   Proof: reindex both sides by the injective pairing maps ?ps and ?ips and
   identify the two image sets with permutes_pair_eq. *)
lemma permutes_sum:
  assumes p: "p permutes S"
  shows "(sS. f (p s) s) = (sS. f s (Hilbert_Choice.inv p s))"
    (is "?l = ?r")
proof -
  let ?f = "λ(x,y). f x y"
  let ?ps = "λs. (p s, s)"
  let ?ips = "λs. (s, Hilbert_Choice.inv p s)"
  have inj1: "inj_on ?ps S" by (rule inj_onI;auto)
  have inj2: "inj_on ?ips S" by (rule inj_onI;auto)
  have "?l = sum ?f (?ps ` S)"
    using sum.reindex[OF inj1, of ?f] by simp
  also have "?ps ` S = {(p s, s) |s. s  S}" by auto
  also have "... = {(s, Hilbert_Choice.inv p s) | s. s  S}"
    unfolding permutes_pair_eq[OF p] by simp
  also have "... = ?ips ` S" by auto
  also have "sum ?f ... = ?r"
    using sum.reindex[OF inj2, of ?f] by simp
  finally show ?thesis.
qed

(* Hilbert_Choice.inv is injective on the set of permutations of S: equal
   inverses are inverted once more and inv_inv_eq (with permutes_bij) recovers
   the original permutations. *)
lemma inv_inj_on_permutes: "inj_on Hilbert_Choice.inv { p. p permutes S }"
proof (intro inj_onI, unfold mem_Collect_eq)
  let ?i = "Hilbert_Choice.inv"
  fix p q
  assume p: "p permutes S" and q: "q permutes S" and eq: "?i p = ?i q"
  have "?i (?i p) = ?i (?i q)" using eq by simp
  thus "p = q"
    using inv_inv_eq[OF permutes_bij] p q by metis
qed

(* For p permuting S and x satisfying assumption x, p x = x; read off directly
   from permutes_def.
   NOTE(review): the relation in assumption x is missing in this copy;
   the lemma name suggests x lies outside S, confirm. *)
lemma permutes_others:
  assumes p: "p permutes S" and x: "x  S" shows "p x = x"
  using p unfolding permutes_def using x by simp

end

Theory Conjugate

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
theory Conjugate
  imports HOL.Complex
begin

(* Type class with an operation conjugate on 'a. Axioms (both simp rules):
   conjugate is an involution (conjugate_id), and conjugate_cancel_iff relates
   "conjugate a = conjugate b" to "a = b". *)
class conjugate =
  fixes conjugate :: "'a  'a"
  assumes conjugate_id[simp]: "conjugate (conjugate a) = a"
      and conjugate_cancel_iff[simp]: "conjugate a = conjugate b  a = b"

(* Rings with a conjugate operation that distributes over multiplication and
   addition, commutes with unary minus, and fixes 0. *)
class conjugatable_ring = ring + conjugate +
  assumes conjugate_dist_mul: "conjugate (a * b) = conjugate a * conjugate b"
      and conjugate_dist_add: "conjugate (a + b) = conjugate a + conjugate b"
      and conjugate_neg: "conjugate (-a) = - conjugate a"
      and conjugate_zero[simp]: "conjugate 0 = 0"
begin
  (* instance of conjugate_cancel_iff at 0, rewritten with conjugate_zero *)
  lemma conjugate_zero_iff[simp]: "conjugate a = 0  a = 0"
    using conjugate_cancel_iff[of _ 0, unfolded conjugate_zero].
end

(* conjugatable_field: conjugatable_ring combined with field; no extra axioms. *)
class conjugatable_field = conjugatable_ring + field

(* conjugate commutes with finite sums: induction over the finite set X using
   conjugate_dist_add. *)
lemma sum_conjugate:
  fixes f :: "'b  'a :: conjugatable_ring"
  assumes finX: "finite X"
  shows "conjugate (sum f X) = sum (λx. conjugate (f x)) X"
  using finX by (induct set:finite, auto simp: conjugate_dist_add)

(* Conjugatable rings with an ordered_comm_monoid_add order and an axiom
   comparing a * conjugate a with 0 (NOTE(review): the comparison symbol is
   missing in this copy; the axiom name suggests non-negativity). *)
class conjugatable_ordered_ring = conjugatable_ring + ordered_comm_monoid_add +
  assumes conjugate_square_positive: "a * conjugate a  0"

(* Field version; it is in particular a conjugatable_field. *)
class conjugatable_ordered_field = conjugatable_ordered_ring + field
begin
  subclass conjugatable_field..
end

(* Relates "a * conjugate a = 0" and "a = 0" for types of sorts
   conjugatable_ordered_ring and semiring_no_zero_divisors; proved by auto. *)
lemma conjugate_square_0:
  fixes a :: "'a :: {conjugatable_ordered_ring, semiring_no_zero_divisors}"
  shows "a * conjugate a = 0  a = 0" by auto


subsection ‹Instantiations›

(* Instance for complex: conjugate is cnj, and both order relations are
   defined from a condition on the imaginary parts (Im x = Im y) together with
   the corresponding comparison of the real parts.
   NOTE(review): the connective between the two conditions is missing in this
   copy; presumably a conjunction, confirm. *)
instantiation complex :: conjugatable_ordered_field
begin
  definition [simp]: "conjugate  cnj"
  definition [simp]: "x < y  Im x = Im y  Re x < Re y"
  definition [simp]: "x  y  Im x = Im y  Re x  Re y"
  
  instance by (intro_classes, auto simp: complex.expand)
end

(* Instances for real, rat and int: conjugate is defined from x itself
   (presumably as the identity; the defining symbol is missing in this copy),
   and each instance proof is intro_classes followed by auto. real and rat are
   instantiated as conjugatable_ordered_field, int as conjugatable_ordered_ring. *)
instantiation real :: conjugatable_ordered_field
begin
  definition [simp]: "conjugate (x::real)  x"
  instance by (intro_classes, auto)
end

instantiation rat :: conjugatable_ordered_field
begin
  definition [simp]: "conjugate (x::rat)  x"
  instance by (intro_classes, auto)
end

instantiation int :: conjugatable_ordered_ring
begin
  definition [simp]: "conjugate (x::int)  x"
  instance by (intro_classes, auto)
end

(* Simp rules relating a vanishing product of x and conjugate x (in either
   order) to x = 0, for conjugatable rings without zero divisors. *)
lemma conjugate_square_eq_0 [simp]:
  fixes x :: "'a :: {conjugatable_ring,semiring_no_zero_divisors}"
  shows "x * conjugate x = 0  x = 0" "conjugate x * x = 0  x = 0"
  by auto

(* Relates "x * conjugate x > 0" to a condition comparing x with 0
   (NOTE(review): both the connective and the relation on x are missing in this
   copy). Derived from conjugate_square_positive with le_less. *)
lemma conjugate_square_greater_0 [simp]:
  fixes x :: "'a :: {conjugatable_ordered_ring,ring_no_zero_divisors}"
  shows "x * conjugate x > 0  x  0" 
  using conjugate_square_positive[of x]
  by (auto simp: le_less)

(* x * conjugate x is never below 0; from conjugate_square_positive. *)
lemma conjugate_square_smaller_0 [simp]:
  fixes x :: "'a :: {conjugatable_ordered_ring,ring_no_zero_divisors}"
  shows "¬ x * conjugate x < 0"
  using conjugate_square_positive[of x] by auto

end

Theory Matrix

(*
    Author:      René Thiemann
                 Akihisa Yamada
    License:     BSD
*)
(* with contributions from Alexander Bentkamp, Universität des Saarlandes *)

section‹Vectors and Matrices›

text ‹We define vectors as pairs of dimension and a characteristic function from natural numbers
to elements.
Similarly, matrices are defined as triples of two dimensions and one
characteristic function from pairs of natural numbers to elements.
Via a subtype we ensure that the characteristic function always behaves the same
on indices outside the intended range. Hence, every matrix has a unique representation.

In this part we define basic operations like matrix-addition, -multiplication, scalar-product,
etc. We connect these operations to HOL-Algebra with its explicit carrier sets.›

theory Matrix
imports
  Missing_Ring
  "HOL-Algebra.Module"
  Polynomial_Interpolation.Ring_Hom
  Conjugate
begin

subsection‹Vectors›

text ‹Here we specify which value should be returned in case
  an index is out of bounds. The current solution has the advantage
  that in the implementation later on, no index comparison has to be performed.›

(* Value used for out-of-bounds indices: the i-th element of the empty list. *)
definition undef_vec :: "nat  'a" where
  "undef_vec i  [] ! i"

(* mk_vec n f agrees with f on indices i < n and is undef_vec (i - n) on all
   other indices, so it depends only on the values of f below n. *)
definition mk_vec :: "nat  (nat  'a)  (nat  'a)" where
  "mk_vec n f  λ i. if i < n then f i else undef_vec (i - n)"

(* The type 'a vec: all pairs (n, mk_vec n f) of a dimension n and a
   normalised index function. *)
typedef 'a vec = "{(n, mk_vec n f) | n f :: nat  'a. True}"
  by auto

(* enables lift_definition / transfer for the new type *)
setup_lifting type_definition_vec

(* Basic operations lifted from the pair representation:
   dim_vec is the first component, vec_index (infix $) the second. *)
lift_definition dim_vec :: "'a vec  nat" is fst .
lift_definition vec_index :: "'a vec  (nat  'a)" (infixl "$" 100) is snd .
(* vec n f: constructor from a dimension and an index function *)
lift_definition vec :: "nat  (nat  'a)  'a vec"
  is "λ n f. (n, mk_vec n f)" by auto

(* conversion from a list: dimension is the length, entries are given by nth *)
lift_definition vec_of_list :: "'a list  'a vec" is
  "λ v. (length v, mk_vec (length v) (nth v))" by auto

(* conversion to a list: the entries at indices 0 ..< n, in order *)
lift_definition list_of_vec :: "'a vec  'a list" is
  "λ (n,v). map v [0 ..< n]" .

(* carrier_vec n: the set of all vectors of dimension n. *)
definition carrier_vec :: "nat  'a vec set" where
  "carrier_vec n = { v . dim_vec v = n}"

(* a vector and the carrier of its own dimension *)
lemma carrier_vec_dim_vec[simp]: "v  carrier_vec (dim_vec v)" unfolding carrier_vec_def by auto

(* dimension, carrier and entries of a constructed vector *)
lemma dim_vec[simp]: "dim_vec (vec n f) = n" by transfer simp
lemma vec_carrier[simp]: "vec n f  carrier_vec n" unfolding carrier_vec_def by auto
lemma index_vec[simp]: "i < n  vec n f $ i = f i" by transfer (simp add: mk_vec_def)
(* extensionality: the premises are agreement of v and w on indices below
   dim_vec w and equality of the dimensions; conclusion v = w *)
lemma eq_vecI[intro]: "( i. i < dim_vec w  v $ i = w $ i)  dim_vec v = dim_vec w
   v = w"
  by (transfer, auto simp: mk_vec_def)

(* conversions between carrier_vec statements and dim_vec equations *)
lemma carrier_dim_vec: "v  carrier_vec n  dim_vec v = n"
  unfolding carrier_vec_def by auto

lemma carrier_vecD[simp]: "v  carrier_vec n  dim_vec v = n" using carrier_dim_vec by auto

lemma carrier_vecI: "dim_vec v = n  v  carrier_vec n" using carrier_dim_vec by auto

(* Pointwise addition of vectors; the result takes its dimension from the
   second argument v2. *)
instantiation vec :: (plus) plus
begin
definition plus_vec :: "'a vec  'a vec  'a :: plus vec" where
  "v1 + v2  vec (dim_vec v2) (λ i. v1 $ i + v2 $ i)"
instance ..
end

(* Pointwise subtraction of vectors; the result takes its dimension from the
   second argument v2. *)
instantiation vec :: (minus) minus
begin
definition minus_vec :: "'a vec  'a vec  'a :: minus vec" where
  "v1 - v2  vec (dim_vec v2) (λ i. v1 $ i - v2 $ i)"
instance ..
end

(* zero_vec n: the vector of dimension n whose entries are all 0. *)
definition
  zero_vec :: "nat  'a :: zero vec" ("0v")
  where "0v n  vec n (λ i. 0)"

lemma zero_carrier_vec[simp]: "0v n  carrier_vec n"
  unfolding zero_vec_def carrier_vec_def by auto

(* entries (for i < n) and dimension of the zero vector *)
lemma index_zero_vec[simp]: "i < n  0v n $ i = 0" "dim_vec (0v n) = n"
  unfolding zero_vec_def by auto

(* relates "dim_vec v = 0" to v being the zero vector of dimension 0 *)
lemma vec_of_dim_0[simp]: "dim_vec v = 0  v = 0v 0" by auto

(* unit_vec n i: the vector of dimension n with entry 1 at index i and 0 at
   every other index. *)
definition
  unit_vec :: "nat  nat  ('a :: zero_neq_one) vec"
  where "unit_vec n i = vec n (λ j. if j = i then 1 else 0)"

(* entries (for indices below n) and dimension of a unit vector *)
lemma index_unit_vec[simp]:
  "i < n  j < n  unit_vec n i $ j = (if j = i then 1 else 0)"
  "i < n  unit_vec n i $ i = 1"
  "dim_vec (unit_vec n i) = n"
  unfolding unit_vec_def by auto

(* For i < n, two unit vectors of dimension n are equal exactly when their
   indices agree; the auxiliary fact compares the two vectors at position i. *)
lemma unit_vec_eq[simp]:
  assumes i: "i < n"
  shows "(unit_vec n i = unit_vec n j) = (i = j)"
proof -
  have "i  j  unit_vec n i $ i  unit_vec n j $ i"
    unfolding unit_vec_def using i by simp
  then show ?thesis by metis
qed

(* Compares unit_vec n i with zero_vec n for i < n, using that the entries at
   position i are 1 and 0 respectively.
   NOTE(review): the relation symbol is missing in this copy; the lemma name
   indicates inequality. *)
lemma unit_vec_nonzero[simp]:
  assumes i_n: "i < n" shows "unit_vec n i  zero_vec n" (is "?l  ?r")
proof -
  have "?l $ i = 1" "?r $ i = 0" using i_n by auto
  thus "?l  ?r" by auto
qed

lemma unit_vec_carrier[simp]: "unit_vec n i  carrier_vec n"
  unfolding unit_vec_def carrier_vec_def by auto

(* unit_vecs n: the list unit_vec n 0, ..., unit_vec n (n - 1), built by
   mapping unit_vec n over [0..<n]. *)
definition unit_vecs:: "nat  'a :: zero_neq_one vec list"
  where "unit_vecs n = map (unit_vec n) [0..<n]"

text "List of the first i unit vectors"

(* unit_vecs_first n i: recursively appends unit_vec n i at the end, starting
   from the empty list for i = 0. *)
fun unit_vecs_first:: "nat  nat  'a::zero_neq_one vec list"
  where "unit_vecs_first n 0 = []"
    |   "unit_vecs_first n (Suc i) = unit_vecs_first n i @ [unit_vec n i]"

(* unit_vecs n coincides with unit_vecs_first n n. The braces prove, by
   induction on m, the generalised statement for a prefix [0..<m] under a
   bound relating m and n; the final step specialises it. *)
lemma unit_vecs_first: "unit_vecs n = unit_vecs_first n n"
  unfolding unit_vecs_def set_map set_upt
proof -
  {fix m
    have "m  n  map (unit_vec n) [0..<m] = unit_vecs_first n m"
    proof (induct m)
      case (Suc m) then have mn:"mn" by auto
        show ?case unfolding upt_Suc using Suc(1)[OF mn] by auto
    qed auto
  }
  thus "map (unit_vec n) [0..<n] = unit_vecs_first n n" by auto
qed

text "List of the last i unit vectors"

(* unit_vecs_last n i: recursively prepends unit_vec n (n - Suc i) at the
   front, starting from the empty list for i = 0. *)
fun unit_vecs_last:: "nat  nat  'a :: zero_neq_one vec list"
  where "unit_vecs_last n 0 = []"
    |   "unit_vecs_last n (Suc i) = unit_vec n (n - Suc i) # unit_vecs_last n i"

lemma unit_vecs_last_carrier: "set (unit_vecs_last n i)  carrier_vec n"
  by (induct i;auto)

(* unit_vecs n coincides with unit_vecs_last n n; declared [code] so that this
   equation is used as the code equation for unit_vecs.
   The inner block proves the generalised statement about the last m indices
   [n-m..<n] by induction on m. *)
lemma unit_vecs_last[code]: "unit_vecs n = unit_vecs_last n n"
proof -
  { fix m assume "m = n"
    have "m  n  map (unit_vec n) [n-m..<n] = unit_vecs_last n m"
      proof (induction m)
      case (Suc m)
        then have nm:"n - Suc m < n" by auto
        (* split off the first index of the interval so that the Cons in
           unit_vecs_last.simps can be matched *)
        have ins: "[n - Suc m ..< n] = (n - Suc m) # [n - m ..< n]"
          unfolding upt_conv_Cons[OF nm]
          by (auto simp: Suc.prems Suc_diff_Suc Suc_le_lessD)
        show ?case
          unfolding ins
          unfolding unit_vecs_last.simps
          unfolding list.map
          using Suc
          unfolding Suc by auto
      qed simp
  }
  thus "unit_vecs n = unit_vecs_last n n"
    unfolding unit_vecs_def by auto
qed

(* All vectors in unit_vecs n have dimension n: each element is some
   unit_vec n i, so unit_vec_carrier applies. *)
lemma unit_vecs_carrier: "set (unit_vecs n)  carrier_vec n"
proof
  fix u :: "'a vec"  assume u: "u  set (unit_vecs n)"
  then obtain i where "u = unit_vec n i" unfolding unit_vecs_def by auto
  then show "u  carrier_vec n"
    using unit_vec_carrier by auto
qed

(* A unit vector with index below n - j does not occur among the last j unit
   vectors. *)
lemma unit_vecs_last_distinct:
  "j  n  i < n - j  unit_vec n i  set (unit_vecs_last n j)"
  by (induction j arbitrary:i, auto)

(* A unit vector with index j (i bounded by j, j < n) does not occur among the
   first i unit vectors. *)
lemma unit_vecs_first_distinct:
  "i  j  j < n  unit_vec n j  set (unit_vecs_first n i)"
  by (induction i arbitrary:j, auto)

(* map_vec f v: applies f to every entry of v; the dimension is unchanged. *)
definition map_vec where "map_vec f v  vec (dim_vec v) (λi. f (v $ i))"

(* Unary minus on vectors is defined entrywise; the dimension is unchanged. *)
instantiation vec :: (uminus) uminus
begin
definition uminus_vec :: "'a :: uminus vec  'a vec" where
  "- v  vec (dim_vec v) (λ i. - (v $ i))"
instance ..
end

(* Scalar multiplication: multiplies every entry of v from the left by a. *)
definition smult_vec :: "'a :: times  'a vec  'a vec" (infixl "v" 70)
  where "a v v  vec (dim_vec v) (λ i. a * v $ i)"

(* Scalar product: sum of the entrywise products. The index range is taken
   from the dimension of the second argument w. *)
definition scalar_prod :: "'a vec  'a vec  'a :: semiring_0" (infix "" 70)
  where "v  w   i  {0 ..< dim_vec w}. v $ i * w $ i"

(* Monoid record for vectors of dimension n: the record's mult field is vector
   addition and its one field is the zero vector. The ty argument only fixes
   the element type. *)
definition monoid_vec :: "'a itself  nat  ('a :: monoid_add vec) monoid" where
  "monoid_vec ty n  
    carrier = carrier_vec n,
    mult = (+),
    one = 0v n"

(* Module record for vectors of dimension n. The multiplicative fields mult
   and one are left undefined; zero, add and smult are the zero vector, vector
   addition and scalar multiplication. *)
definition module_vec ::
  "'a :: semiring_1 itself  nat  ('a,'a vec) module" where
  "module_vec ty n  
    carrier = carrier_vec n,
    mult = undefined,
    one = undefined,
    zero = 0v n,
    add = (+),
    smult = (⋅v)"

(* Field projections of monoid_vec. *)
lemma monoid_vec_simps:
  "mult (monoid_vec ty n) = (+)"
  "carrier (monoid_vec ty n) = carrier_vec n"
  "one (monoid_vec ty n) = 0v n"
  unfolding monoid_vec_def by auto

(* Field projections of module_vec (only the defined fields). *)
lemma module_vec_simps:
  "add (module_vec ty n) = (+)"
  "zero (module_vec ty n) = 0v n"
  "carrier (module_vec ty n) = carrier_vec n"
  "smult (module_vec ty n) = (⋅v)"
  unfolding module_vec_def by auto

(* Finite sum of a family of vectors, defined as finprod over monoid_vec
   (whose mult field is vector addition). *)
definition finsum_vec :: "'a :: monoid_add itself  nat  ('c  'a vec)  'c set  'a vec" where
  "finsum_vec ty n = finprod (monoid_vec ty n)"

(* Entry and dimension of a sum; both are stated relative to the dimension of
   the second summand v2. *)
lemma index_add_vec[simp]:
  "i < dim_vec v2  (v1 + v2) $ i = v1 $ i + v2 $ i" "dim_vec (v1 + v2) = dim_vec v2"
  unfolding plus_vec_def by auto

(* Entry and dimension of a difference, relative to the dimension of v2. *)
lemma index_minus_vec[simp]:
  "i < dim_vec v2  (v1 - v2) $ i = v1 $ i - v2 $ i" "dim_vec (v1 - v2) = dim_vec v2"
  unfolding minus_vec_def by auto

(* Entry and dimension of map_vec f v. *)
lemma index_map_vec[simp]:
  "i < dim_vec v  map_vec f v $ i = f (v $ i)"
  "dim_vec (map_vec f v) = dim_vec v"
  unfolding map_vec_def by auto

(* map_vec preserves carrier membership in both directions. *)
lemma map_carrier_vec[simp]: "map_vec h v  carrier_vec n = (v  carrier_vec n)"
  unfolding map_vec_def carrier_vec_def by auto

(* Entry and dimension of the negated vector. *)
lemma index_uminus_vec[simp]:
  "i < dim_vec v  (- v) $ i = - (v $ i)"
  "dim_vec (- v) = dim_vec v"
  unfolding uminus_vec_def by auto

(* Entry and dimension of a scalar multiple. *)
lemma index_smult_vec[simp]:
  "i < dim_vec v  (a v v) $ i = a * v $ i" "dim_vec (a v v) = dim_vec v"
  unfolding smult_vec_def by auto

(* carrier_vec n is closed under addition. *)
lemma add_carrier_vec[simp]:
  "v1  carrier_vec n  v2  carrier_vec n  v1 + v2  carrier_vec n"
  unfolding carrier_vec_def by auto

(* carrier_vec n is closed under subtraction. *)
lemma minus_carrier_vec[simp]:
  "v1  carrier_vec n  v2  carrier_vec n  v1 - v2  carrier_vec n"
  unfolding carrier_vec_def by auto

(* Commutativity of vector addition for vectors of the same dimension. *)
lemma comm_add_vec[ac_simps]:
  "(v1 :: 'a :: ab_semigroup_add vec)  carrier_vec n  v2  carrier_vec n  v1 + v2 = v2 + v1"
  by (intro eq_vecI, auto simp: ac_simps)

(* Associativity of vector addition for vectors of the same dimension. *)
lemma assoc_add_vec[simp]:
  "(v1 :: 'a :: semigroup_add vec)  carrier_vec n  v2  carrier_vec n  v3  carrier_vec n
   (v1 + v2) + v3 = v1 + (v2 + v3)"
  by (intro eq_vecI, auto simp: ac_simps)

(* Subtracting v from the zero vector yields - v. *)
lemma zero_minus_vec[simp]: "(v :: 'a :: group_add vec)  carrier_vec n  0v n - v = - v"
  by (intro eq_vecI, auto)

(* Subtracting the zero vector leaves v unchanged. *)
lemma minus_zero_vec[simp]: "(v :: 'a :: group_add vec)  carrier_vec n  v - 0v n = v"
  by (intro eq_vecI, auto)

(* v - v is the zero vector. *)
lemma minus_cancel_vec[simp]: "(v :: 'a :: group_add vec)  carrier_vec n  v - v = 0v n"
  by (intro eq_vecI, auto)

(* Subtraction expressed as addition of the negated vector (not a simp rule). *)
lemma minus_add_uminus_vec: "(v :: 'a :: group_add vec)  carrier_vec n 
  w  carrier_vec n  v - w = v + (- w)"
  by (intro eq_vecI, auto)

(* monoid_vec over a commutative additive monoid satisfies the comm_monoid
   locale. *)
lemma comm_monoid_vec: "comm_monoid (monoid_vec TYPE ('a :: comm_monoid_add) n)"
  by (unfold_locales, auto simp: monoid_vec_def ac_simps)

(* The zero vector is a left neutral element for addition. *)
lemma left_zero_vec[simp]: "(v :: 'a :: monoid_add vec)  carrier_vec n   0v n + v = v" by auto

(* The zero vector is a right neutral element for addition. *)
lemma right_zero_vec[simp]: "(v :: 'a :: monoid_add vec)  carrier_vec n   v + 0v n = v" by auto


(* Negation preserves carrier membership in both directions. *)
lemma uminus_carrier_vec[simp]:
  "(- v  carrier_vec n) = (v  carrier_vec n)"
  unfolding carrier_vec_def by auto

(* - v is a right inverse of v with respect to addition. *)
lemma uminus_r_inv_vec[simp]:
  "(v :: 'a :: group_add vec)  carrier_vec n  (v + - v) = 0v n"
  by (intro eq_vecI, auto)

(* - v is a left inverse of v with respect to addition. *)
lemma uminus_l_inv_vec[simp]:
  "(v :: 'a :: group_add vec)  carrier_vec n  (- v + v) = 0v n"
  by (intro eq_vecI, auto)

(* Existence of a two-sided additive inverse inside carrier_vec n; the witness
   used in the proof is - v. *)
lemma add_inv_exists_vec:
  "(v :: 'a :: group_add vec)  carrier_vec n   w  carrier_vec n. w + v = 0v n  v + w = 0v n"
  by (intro bexI[of _ "- v"], auto)

(* monoid_vec over an abelian additive group satisfies the comm_group locale;
   inverses come from add_inv_exists_vec. *)
lemma comm_group_vec: "comm_group (monoid_vec TYPE ('a :: ab_group_add) n)"
  by (unfold_locales, insert add_inv_exists_vec, auto simp: monoid_vec_def ac_simps Units_def)

(* The following three facts are the comm_monoid finprod rules (insert, closed,
   empty) instantiated with monoid_vec and restated in terms of finsum_vec,
   vector addition and the zero vector. *)
lemmas finsum_vec_insert =
  comm_monoid.finprod_insert[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

lemmas finsum_vec_closed =
  comm_monoid.finprod_closed[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

lemmas finsum_vec_empty =
  comm_monoid.finprod_empty[OF comm_monoid_vec, folded finsum_vec_def, unfolded monoid_vec_simps]

(* Scalar multiplication preserves carrier membership in both directions. *)
lemma smult_carrier_vec[simp]: "(a v v  carrier_vec n) = (v  carrier_vec n)"
  unfolding carrier_vec_def by auto

(* Scalar product with the zero vector on the left is 0. *)
lemma scalar_prod_left_zero[simp]: "v  carrier_vec n  0v n  v = 0"
  unfolding scalar_prod_def
  by (rule sum.neutral, auto)

(* Scalar product with the zero vector on the right is 0. *)
lemma scalar_prod_right_zero[simp]: "v  carrier_vec n  v  0v n = 0"
  unfolding scalar_prod_def
  by (rule sum.neutral, auto)

(* Scalar product with the i-th unit vector on the left selects entry i of v.
   Proof idea: split the summand for index i off the sum (sum.remove); all
   remaining summands are 0 because the unit vector is 0 away from i. *)
lemma scalar_prod_left_unit[simp]: assumes v: "(v :: 'a :: semiring_1 vec)  carrier_vec n" and i: "i < n"
  shows "unit_vec n i  v = v $ i"
proof -
  let ?f = "λ k. unit_vec n i $ k * v $ k"
  have id: "(k{0..<n}. ?f k) = unit_vec n i $ i * v $ i + (k{0..<n} - {i}. ?f k)"
    by (rule sum.remove, insert i, auto)
  also have "( k{0..<n} - {i}. ?f k) = 0"
    by (rule sum.neutral, insert i, auto)
  finally
  show ?thesis unfolding scalar_prod_def using i v by simp
qed

(* Scalar product with the i-th unit vector on the right selects entry i of v.
   No carrier assumption on v is needed here: the summation range in
   scalar_prod_def is determined by the second argument, i.e. the unit vector.
   The proof mirrors scalar_prod_left_unit. *)
lemma scalar_prod_right_unit[simp]: assumes i: "i < n"
  shows "(v :: 'a :: semiring_1 vec)  unit_vec n i = v $ i"
proof -
  let ?f = "λ k. v $ k * unit_vec n i $ k"
  have id: "(k{0..<n}. ?f k) = v $ i * unit_vec n i $ i + (k{0..<n} - {i}. ?f k)"
    by (rule sum.remove, insert i, auto)
  also have "(k{0..<n} - {i}. ?f k) = 0"
    by (rule sum.neutral, insert i, auto)
  finally
  show ?thesis unfolding scalar_prod_def using i by simp
qed

(* The scalar product distributes over addition in its first argument.
   First the summands are rewritten entrywise, then sum.distrib splits the sum. *)
lemma add_scalar_prod_distrib: assumes v: "v1  carrier_vec n" "v2  carrier_vec n" "v3  carrier_vec n"
  shows "(v1 + v2)  v3 = v1  v3 + v2  v3"
proof -
  have "(i{0..<dim_vec v3}. (v1 + v2) $ i * v3 $ i) = (i{0..<dim_vec v3}. v1 $ i * v3 $ i + v2 $ i * v3 $ i)"
    by (rule sum.cong, insert v, auto simp: algebra_simps)
  thus ?thesis unfolding scalar_prod_def using v by (auto simp: sum.distrib)
qed

(* The scalar product distributes over addition in its second argument.
   Same structure as add_scalar_prod_distrib, with sum.distrib used as an
   intro rule in the final step. *)
lemma scalar_prod_add_distrib: assumes v: "v1  carrier_vec n" "v2  carrier_vec n" "v3  carrier_vec n"
  shows "v1  (v2 + v3) = v1  v2 + v1  v3"
proof -
  have "(i{0..<dim_vec v3}. v1 $ i * (v2 + v3) $ i) = (i{0..<dim_vec v3}. v1 $ i * v2 $ i + v1 $ i * v3 $ i)"
    by (rule sum.cong, insert v, auto simp: algebra_simps)
  thus ?thesis unfolding scalar_prod_def using v by (auto intro: sum.distrib)
qed

(* A scalar factor on the first argument can be pulled out of the scalar
   product. *)
lemma smult_scalar_prod_distrib[simp]: assumes v: "v1  carrier_vec n" "v2  carrier_vec n"
  shows "(a v v1)  v2 = a * (v1  v2)"
  unfolding scalar_prod_def sum_distrib_left
  by (rule sum.cong, insert v, auto simp: ac_simps)

(* A scalar factor on the second argument can be pulled out of the scalar
   product; stated for commutative rings, since the factor is moved across
   the entries of v1. *)
lemma scalar_prod_smult_distrib[simp]: assumes v: "v1  carrier_vec n" "v2  carrier_vec n"
  shows "v1  (a v v2) = (a :: 'a :: comm_ring) * (v1  v2)"
  unfolding scalar_prod_def sum_distrib_left
  by (rule sum.cong, insert v, auto simp: ac_simps)

(* The scalar product is commutative over commutative semirings, for vectors
   of equal dimension. *)
lemma comm_scalar_prod: assumes "(v1 :: 'a :: comm_semiring_0 vec)  carrier_vec n" "v2  carrier_vec n"
  shows "v1  v2 = v2  v1"
  unfolding scalar_prod_def
  by (rule sum.cong, insert assms, auto simp: ac_simps)

(* Scalar multiplication distributes over addition of scalars. *)
lemma add_smult_distrib_vec:
  "((a::'a::ring) + b) v v = a v v + b v v"
  unfolding smult_vec_def plus_vec_def
  by (rule eq_vecI, auto simp: distrib_right)

(* Scalar multiplication distributes over addition of vectors of equal
   dimension. *)
lemma smult_add_distrib_vec:
  assumes "v  carrier_vec n" "w  carrier_vec n"
  shows "(a::'a::ring) v (v + w) = a v v + a v w"
  apply (rule eq_vecI)
  unfolding smult_vec_def plus_vec_def
  using assms distrib_left by auto

(* Iterated scalar multiplication equals multiplication by the product. *)
lemma smult_smult_assoc:
  "a v (b v v) = (a * b::'a::ring) v v"
  apply (rule sym, rule eq_vecI)
  unfolding smult_vec_def plus_vec_def using mult.assoc by auto

(* Multiplying by the scalar 1 leaves the vector unchanged. *)
lemma one_smult_vec [simp]:
  "(1::'a::ring_1) v v = v" unfolding smult_vec_def
  by (rule eq_vecI,auto)

(* Negating the zero vector yields the zero vector. *)
lemma uminus_zero_vec[simp]: "- (0v n) = (0v n :: 'a :: group_add vec)" 
  by (intro eq_vecI, auto)

(* Entry i of a finite sum of vectors (all of dimension n) is the sum of the
   i-th entries. Proved by induction over the finite index set F: the insert
   case unfolds finsum_vec_insert and sum.insert and then uses index_add_vec;
   the empty case uses finsum_vec_empty. *)
lemma index_finsum_vec: assumes "finite F" and i: "i < n"
  and vs: "vs  F  carrier_vec n"
  shows "finsum_vec TYPE('a :: comm_monoid_add) n vs F $ i = sum (λ f. vs f $ i) F"
  using ‹finite F vs
proof (induct F)
  case (insert f F)
  hence IH: "finsum_vec TYPE('a) n vs F $ i = (fF. vs f $ i)"
    and vs: "vs  F  carrier_vec n" "vs f  carrier_vec n" by auto
  show ?case unfolding finsum_vec_insert[OF insert(1-2) vs]
    unfolding sum.insert[OF insert(1-2)]
    unfolding IH[symmetric]
    by (rule index_add_vec, insert i, insert finsum_vec_closed[OF vs(1)], auto)
qed (insert i, auto simp: finsum_vec_empty)

text ‹Definition of the pointwise ordering on vectors: the non-strict part is defined entrywise, and
  the strict version is defined in such a way that the @{class order} constraints are satisfied.›

(* Ordering on vectors. less_eq_vec requires equal dimensions and compares all
   entries; less_vec is derived from less_eq_vec as "below but not above". *)
instantiation vec :: (ord) ord
begin

definition less_eq_vec :: "'a vec  'a vec  bool" where
  "less_eq_vec v w = (dim_vec v = dim_vec w  ( i < dim_vec w. v $ i  w $ i))" 

definition less_vec :: "'a vec  'a vec  bool" where
  "less_vec v w = (v  w  ¬ (w  v))"
instance ..
end

(* The vector ordering is a preorder whenever the entry type is one. *)
instantiation vec :: (preorder) preorder
begin
instance
  by (standard, auto simp: less_vec_def less_eq_vec_def order_trans)
end

(* The vector ordering is a partial order whenever the entry type is one;
   antisymmetry is reduced to entrywise antisymmetry via eq_vecI. *)
instantiation vec :: (order) order
begin
instance
  by (standard, intro eq_vecI, auto simp: less_eq_vec_def order.antisym)
end


subsection‹Matrices›

text ‹As for vectors, we specify which value is returned when
  an index is out of bounds. The definition is chosen so that only a few
  index comparisons have to be performed in the implementation.›

(* undef_mat nr nc f: looks up position (i,j) in the list-of-lists
   representation of f restricted to nr rows and nc columns, using list
   indexing (!). It is used by mk_mat as the value for out-of-bounds indices. *)
definition undef_mat :: "nat  nat  (nat × nat  'a)  nat × nat  'a" where
  "undef_mat nr nc f  λ (i,j). [[f (i,j). j <- [0 ..< nc]] . i <- [0 ..< nr]] ! i ! j"

(* undef_mat only depends on the values of f within the nr x nc bounds.
   The local fact nth_map_ge rewrites an out-of-range list access into an
   access on the empty list, which no longer mentions f. *)
lemma undef_cong_mat: assumes " i j. i < nr  j < nc  f (i,j) = f' (i,j)"
  shows "undef_mat nr nc f x = undef_mat nr nc f' x"
proof (cases x)
  case (Pair i j)
  have nth_map_ge: " i xs. ¬ i < length xs  xs ! i = [] ! (i - length xs)"
    by (metis append_Nil2 nth_append)
  note [simp] = Pair undef_mat_def nth_map_ge[of i] nth_map_ge[of j]
  show ?thesis
    by (cases "i < nr", simp, cases "j < nc", insert assms, auto)
qed

(* mk_mat nr nc f: normalised index function of a matrix. Inside the bounds it
   returns f (i,j); outside it falls back to undef_mat. *)
definition mk_mat :: "nat  nat  (nat × nat  'a)  (nat × nat  'a)" where
  "mk_mat nr nc f  λ (i,j). if i < nr  j < nc then f (i,j) else undef_mat nr nc f (i,j)"

(* Two functions that agree inside the bounds give the same mk_mat. *)
lemma cong_mk_mat: assumes " i j. i < nr  j < nc  f (i,j) = f' (i,j)"
  shows "mk_mat nr nc f = mk_mat nr nc f'"
  using undef_cong_mat[of nr nc f f', OF assms]
  using assms unfolding mk_mat_def
  by auto

(* The matrix type: triples (number of rows, number of columns, index function)
   where the index function is normalised through mk_mat. *)
typedef 'a mat = "{(nr, nc, mk_mat nr nc f) | nr nc f :: nat × nat  'a. True}"
  by auto

setup_lifting type_definition_mat

(* Projections of the representing triple: row count, column count and the
   index function (infix $$). *)
lift_definition dim_row :: "'a mat  nat" is fst .
lift_definition dim_col :: "'a mat  nat" is "fst o snd" .
lift_definition index_mat :: "'a mat  (nat × nat  'a)" (infixl "$$" 100) is "snd o snd" .
(* Constructor from dimensions and an index function on pairs (i,j). *)
lift_definition mat :: "nat  nat  (nat × nat  'a)  'a mat"
  is "λ nr nc f. (nr, nc, mk_mat nr nc f)" by auto
(* Constructor from dimensions and a function mapping each row index to a
   row vector. *)
lift_definition mat_of_row_fun :: "nat  nat  (nat  'a vec)  'a mat" ("matr")
  is "λ nr nc f. (nr, nc, mk_mat nr nc (λ (i,j). f i $ j))" by auto

(* Converts a matrix into a list of rows, each row being a list of entries. *)
definition mat_to_list :: "'a mat  'a list list" where
  "mat_to_list A = [ [A $$ (i,j) . j <- [0 ..< dim_col A]] . i <- [0 ..< dim_row A]]"

(* A matrix is square when its row and column counts agree. *)
fun square_mat :: "'a mat  bool" where "square_mat A = (dim_col A = dim_row A)"

(* A matrix is upper triangular when every entry (i,j) with j < i and
   i < dim_row A is 0. *)
definition upper_triangular :: "'a::zero mat  bool"
  where "upper_triangular A 
    i < dim_row A.  j < i. A $$ (i,j) = 0"

(* Destruction rule: entries below the diagonal of an upper triangular matrix
   are 0. *)
lemma upper_triangularD[elim] :
  "upper_triangular A  j < i  i < dim_row A  A $$ (i,j) = 0"
unfolding upper_triangular_def by auto

(* Introduction rule for upper_triangular. *)
lemma upper_triangularI[intro] :
  "(i j. j < i  i < dim_row A  A $$ (i,j) = 0)  upper_triangular A"
unfolding upper_triangular_def by auto

(* Row count of matrices built with mat and with mat_of_row_fun. *)
lemma dim_row_mat[simp]: "dim_row (mat nr nc f) = nr" "dim_row (matr nr nc g) = nr"
  by (transfer, simp)+

(* Column count of matrices built with mat and with mat_of_row_fun. *)
lemma dim_col_mat[simp]: "dim_col (mat nr nc f) = nc" "dim_col (matr nr nc g) = nc"
  by (transfer, simp)+

(* carrier_mat nr nc: the set of all matrices with nr rows and nc columns. *)
definition carrier_mat :: "nat  nat  'a mat set"
  where "carrier_mat nr nc = { m . dim_row m = nr  dim_col m = nc}"

(* Every matrix lies in the carrier given by its own dimensions. *)
lemma carrier_mat_triv[simp]: "m  carrier_mat (dim_row m) (dim_col m)"
  unfolding carrier_mat_def by auto

(* mat nr nc f lies in carrier_mat nr nc. *)
lemma mat_carrier[simp]: "mat nr nc f  carrier_mat nr nc"
  unfolding carrier_mat_def by auto

(* elements_mat A: the set of all entries of A at in-bounds positions. *)
definition elements_mat :: "'a mat  'a set"
  where "elements_mat A = set [A $$ (i,j). i <- [0 ..< dim_row A], j <- [0 ..< dim_col A]]"

(* Every element of elements_mat A occurs at some in-bounds position. *)
lemma elements_matD [dest]:
  "a  elements_mat A  i j. i < dim_row A  j < dim_col A  a = A $$ (i,j)"
  unfolding elements_mat_def by force

(* Every in-bounds entry of A is an element of elements_mat A. *)
lemma elements_matI [intro]:
  "A  carrier_mat nr nc  i < nr  j < nc  a = A $$ (i,j)  a  elements_mat A"
  unfolding elements_mat_def carrier_mat_def by force

(* In-bounds entries of matrices built with mat and with mat_of_row_fun. *)
lemma index_mat[simp]:  "i < nr  j < nc  mat nr nc f $$ (i,j) = f (i,j)"
  "i < nr  j < nc  matr nr nc g $$ (i,j) = g i $ j"
  by (transfer', simp add: mk_mat_def)+

(* Extensionality for matrices: equal dimensions and equal in-bounds entries
   imply equality. Bounds are stated relative to B. *)
lemma eq_matI[intro]: "( i j . i < dim_row B  j < dim_col B  A $$ (i,j) = B $$ (i,j))
   dim_row A = dim_row B
   dim_col A = dim_col B
   A = B"
  by (transfer, auto intro!: cong_mk_mat, auto simp: mk_mat_def)

(* Introduction rule for carrier membership from the two dimensions. *)
lemma carrier_matI[intro]:
  assumes "dim_row A = nr" "dim_col A = nc" shows  "A  carrier_mat nr nc"
  using assms unfolding carrier_mat_def by auto

(* Carrier membership determines both dimensions. *)
lemma carrier_matD[dest,simp]: assumes "A  carrier_mat nr nc"
  shows "dim_row A = nr" "dim_col A = nc" using assms
  unfolding carrier_mat_def by auto

(* Congruence for mat: equal dimensions and index functions that agree inside
   the bounds give equal matrices. *)
lemma cong_mat: assumes "nr = nr'" "nc = nc'" " i j. i < nr  j < nc 
  f (i,j) = f' (i,j)" shows "mat nr nc f = mat nr' nc' f'"
  by (rule eq_matI, insert assms, auto)

(* row A i: the i-th row of A as a vector of dimension dim_col A. *)
definition row :: "'a mat  nat  'a vec" where
  "row A i = vec (dim_col A) (λ j. A $$ (i,j))"

(* rows A: the list of all rows of A, in order. *)
definition rows :: "'a mat  'a vec list" where
  "rows A = map (row A) [0..<dim_row A]"

(* Any row has dimension dim_col A (no bound on i needed). *)
lemma row_carrier[simp]: "row A i  carrier_vec (dim_col A)" unfolding row_def by auto

(* All elements of rows A have dimension dim_col A. *)
lemma rows_carrier[simp]: "set (rows A)  carrier_vec (dim_col A)" unfolding rows_def by auto

(* rows A has dim_row A elements. *)
lemma length_rows[simp]: "length (rows A) = dim_row A" unfolding rows_def by auto

(* The i-th element of rows A is row A i. *)
lemma nth_rows[simp]: "i < dim_row A  rows A ! i = row A i"
  unfolding rows_def by auto

(* Rows of a matrix built with mat_of_row_fun, provided the row function
   returns a vector of the right dimension. *)
lemma row_mat_of_row_fun[simp]: "i < nr  dim_vec (f i) = nc  row (matr nr nc f) i = f i"
  by (rule eq_vecI, auto simp: row_def)

(* Elements of rows A have dimension n when A has n columns. *)
lemma set_rows_carrier:
  assumes "A  carrier_mat m n" and "v  set (rows A)" shows "v  carrier_vec n"
  using assms by (auto simp: rows_def row_def)

(* mat_of_rows n rs: matrix with length rs rows and n columns whose (i,j)
   entry is entry j of the i-th vector in rs. *)
definition mat_of_rows :: "nat  'a vec list  'a mat"
  where "mat_of_rows n rs = mat (length rs) n (λ(i,j). rs ! i $ j)"

(* Same as mat_of_rows, but with rows given as lists instead of vectors. *)
definition mat_of_rows_list :: "nat  'a list list  'a mat" where
  "mat_of_rows_list nc rs = mat (length rs) nc (λ (i,j). rs ! i ! j)"

(* Dimensions of mat_of_rows n vs. *)
lemma mat_of_rows_carrier[simp]:
  "mat_of_rows n vs  carrier_mat (length vs) n"
  "dim_row (mat_of_rows n vs) = length vs"
  "dim_col (mat_of_rows n vs) = n"
  unfolding mat_of_rows_def by auto

(* The i-th row of mat_of_rows n vs is vs ! i, if that vector has dimension n. *)
lemma mat_of_rows_row[simp]:
  assumes i:"i < length vs" and n: "vs ! i  carrier_vec n"
  shows "row (mat_of_rows n vs) i = vs ! i"
  unfolding mat_of_rows_def row_def using n i by auto

(* rows is a left inverse of mat_of_rows on lists of vectors of dimension n. *)
lemma rows_mat_of_rows[simp]:
  assumes "set vs  carrier_vec n" shows "rows (mat_of_rows n vs) = vs"
  unfolding rows_def apply (rule nth_equalityI)
  using assms unfolding subset_code(1) by auto

(* mat_of_rows rebuilds a matrix from its own rows. *)
lemma mat_of_rows_rows[simp]:
  "mat_of_rows (dim_col A) (rows A) = A"
  unfolding mat_of_rows_def by (rule, auto simp: row_def)


(* col A j: the j-th column of A as a vector of dimension dim_row A. *)
definition col :: "'a mat  nat  'a vec" where
  "col A j = vec (dim_row A) (λ i. A $$ (i,j))"

(* cols A: the list of all columns of A, in order. *)
definition cols :: "'a mat  'a vec list" where
  "cols A = map (col A) [0..<dim_col A]"

(* mat_of_cols n cs: matrix with n rows and length cs columns whose (i,j)
   entry is entry i of the j-th vector in cs. *)
definition mat_of_cols :: "nat  'a vec list  'a mat"
  where "mat_of_cols n cs = mat n (length cs) (λ(i,j). cs ! j $ i)"

(* Same as mat_of_cols, but with columns given as lists instead of vectors. *)
definition mat_of_cols_list :: "nat  'a list list  'a mat" where
  "mat_of_cols_list nr cs = mat nr (length cs) (λ (i,j). cs ! j ! i)"

(* Any column has dimension dim_row A (no bound on the index needed). *)
lemma col_dim[simp]: "col A i  carrier_vec (dim_row A)" unfolding col_def by auto

(* Dimension of a column; note that this fact is named dim_col, like the
   constant. *)
lemma dim_col[simp]: "dim_vec (col A i) = dim_row A" by auto

(* All elements of cols A have dimension dim_row A. *)
lemma cols_dim[simp]: "set (cols A)  carrier_vec (dim_row A)" unfolding cols_def by auto

(* cols A has dim_col A elements. *)
lemma cols_length[simp]: "length (cols A) = dim_col A" unfolding cols_def by auto

(* The i-th element of cols A is col A i. *)
lemma cols_nth[simp]: "i < dim_col A  cols A ! i = col A i"
  unfolding cols_def by auto

(* Dimensions of mat_of_cols n vs. *)
lemma mat_of_cols_carrier[simp]:
  "mat_of_cols n vs  carrier_mat n (length vs)"
  "dim_row (mat_of_cols n vs) = n"
  "dim_col (mat_of_cols n vs) = length vs"
  unfolding mat_of_cols_def by auto

(* The j-th column of mat_of_cols n vs is vs ! j, if that vector has
   dimension n. *)
lemma col_mat_of_cols[simp]:
  assumes j:"j < length vs" and n: "vs ! j  carrier_vec n"
  shows "col (mat_of_cols n vs) j = vs ! j"
  unfolding mat_of_cols_def col_def using j n by auto

(* cols is a left inverse of mat_of_cols on lists of vectors of dimension n. *)
lemma cols_mat_of_cols[simp]:
  assumes "set vs  carrier_vec n" shows "cols (mat_of_cols n vs) = vs"
  unfolding cols_def apply(rule nth_equalityI)
  using assms unfolding subset_code(1) by auto

(* mat_of_cols rebuilds a matrix from its own columns. *)
lemma mat_of_cols_cols[simp]:
  "mat_of_cols (dim_row A) (cols A) = A"
  unfolding mat_of_cols_def by (rule, auto simp: col_def)


(* Ordering on matrices. less_eq_mat requires equal dimensions and compares
   all entries; less_mat is derived from less_eq_mat as "below but not above". *)
instantiation mat :: (ord) ord
begin

definition less_eq_mat :: "'a mat  'a mat  bool" where
  "less_eq_mat A B = (dim_row A = dim_row B  dim_col A = dim_col B  
      ( i < dim_row B.  j < dim_col B. A $$ (i,j)  B $$ (i,j)))" 

definition less_mat :: "'a mat  'a mat  bool" where
  "less_mat A B = (A  B  ¬ (B  A))"
instance ..
end

(* The matrix ordering is a preorder whenever the entry type is one; the one
   goal left after auto is transitivity at a fixed position (i,j). *)
instantiation mat :: (preorder) preorder
begin
instance
proof (standard, auto simp: less_mat_def less_eq_mat_def, goal_cases)
  case (1 A B C i j)
  thus ?case using order_trans[of "A $$ (i,j)" "B $$ (i,j)" "C $$ (i,j)"] by auto
qed
end

(* The matrix ordering is a partial order whenever the entry type is one;
   antisymmetry is reduced to entrywise antisymmetry via eq_matI. *)
instantiation mat :: (order) order
begin
instance
  by (standard, intro eq_matI, auto simp: less_eq_mat_def order.antisym)
end

(* Matrix addition is entrywise; the result takes its dimensions from the
   second argument B. *)
instantiation mat :: (plus) plus
begin
definition plus_mat :: "('a :: plus) mat  'a mat  'a mat" where
  "A + B  mat (dim_row B) (dim_col B) (λ ij. A $$ ij + B $$ ij)"
instance ..
end

(* map_mat f A: applies f to every entry of A; dimensions are unchanged. *)
definition map_mat :: "('a  'b)  'a mat  'b mat" where
  "map_mat f A  mat (dim_row A) (dim_col A) (λ ij. f (A $$ ij))"

(* Scalar multiplication: multiplies every entry from the left by a. *)
definition smult_mat :: "'a :: times  'a mat  'a mat" (infixl "m" 70)
  where "a m A  map_mat (λ b. a * b) A"

(* The nr x nc matrix whose entries are all 0. *)
definition zero_mat :: "nat  nat  'a :: zero mat" ("0m") where
  "0m nr nc  mat nr nc (λ ij. 0)"

(* The zero matrix has no entries other than 0. *)
lemma elements_0_mat [simp]: "elements_mat (0m nr nc)  {0}"
  unfolding elements_mat_def zero_mat_def by auto

(* Transposition: swaps the dimensions and the two index components. *)
definition transpose_mat :: "'a mat  'a mat" where
  "transpose_mat A  mat (dim_col A) (dim_row A) (λ (i,j). A $$ (j,i))"

(* The n x n identity matrix: 1 on the diagonal and 0 elsewhere. *)
definition one_mat :: "nat  'a :: {zero,one} mat" ("1m") where
  "1m n  mat n n (λ (i,j). if i = j then 1 else 0)"

(* Unary minus on matrices is entrywise; dimensions are unchanged. *)
instantiation mat :: (uminus) uminus
begin
definition uminus_mat :: "'a :: uminus mat  'a mat" where
  "- A  mat (dim_row A) (dim_col A) (λ ij. - (A $$ ij))"
instance ..
end

(* Matrix subtraction is entrywise; the result takes its dimensions from the
   second argument B. *)
instantiation mat :: (minus) minus
begin
definition minus_mat :: "('a :: minus) mat  'a mat  'a mat" where
  "A - B  mat (dim_row B) (dim_col B) (λ ij. A $$ ij - B $$ ij)"
instance ..
end

(* Matrix multiplication: entry (i,j) is the scalar product of row i of A and
   column j of B; the result has dim_row A rows and dim_col B columns. *)
instantiation mat :: (semiring_0) times
begin
definition times_mat :: "'a :: semiring_0 mat  'a mat  'a mat"
  where "A * B  mat (dim_row A) (dim_col B) (λ (i,j). row A i  col B j)"
instance ..
end

(* Matrix-vector multiplication: entry i is the scalar product of row i of A
   with v. *)
definition mult_mat_vec :: "'a :: semiring_0 mat  'a vec  'a vec" (infixl "*v" 70)
  where "A *v v  vec (dim_row A) (λ i. row A i  v)"

(* inverts_mat A B: the product A * B is the identity matrix of size
   dim_row A. *)
definition inverts_mat :: "'a :: semiring_1 mat  'a mat  bool" where
  "inverts_mat A B  A * B = 1m (dim_row A)"

(* A matrix is invertible when it is square and has some B that is an inverse
   on both sides. *)
definition invertible_mat :: "'a :: semiring_1 mat  bool"
  where "invertible_mat A  square_mat A  (B. inverts_mat A B  inverts_mat B A)"

(* Monoid record for nr x nc matrices: the record's mult field is matrix
   addition and its one field is the zero matrix. *)
definition monoid_mat :: "'a :: monoid_add itself  nat  nat  'a mat monoid" where
  "monoid_mat ty nr nc  
    carrier = carrier_mat nr nc,
    mult = (+),
    one = 0m nr nc"

(* Ring record for square n x n matrices, with an arbitrary record extension b. *)
definition ring_mat :: "'a :: semiring_1 itself  nat  'b  ('a mat,'b) ring_scheme" where
  "ring_mat ty n b  
    carrier = carrier_mat n n,
    mult = (*),
    one = 1m n,
    zero = 0m n n,
    add = (+),
     = b"

(* Module record for nr x nc matrices with scalar multiplication as smult.
   The one field is the identity matrix of size nr. *)
definition module_mat :: "'a :: semiring_1 itself  nat  nat  ('a,'a mat)module" where
  "module_mat ty nr nc  
    carrier = carrier_mat nr nc,
    mult = (*),
    one = 1m nr,
    zero = 0m nr nc,
    add = (+),
    smult = (⋅m)"

(* Field projections of ring_mat. *)
lemma ring_mat_simps:
  "mult (ring_mat ty n b) = (*)"
  "add (ring_mat ty n b) = (+)"
  "one (ring_mat ty n b) = 1m n"
  "zero (ring_mat ty n b) = 0m n n"
  "carrier (ring_mat ty n b) = carrier_mat n n"
  unfolding ring_mat_def by auto

(* Field projections of module_mat. *)
lemma module_mat_simps:
  "mult (module_mat ty nr nc) = (*)"
  "add (module_mat ty nr nc) = (+)"
  "one (module_mat ty nr nc) = 1m nr"
  "zero (module_mat ty nr nc) = 0m nr nc"
  "carrier (module_mat ty nr nc) = carrier_mat nr nc"
  "smult (module_mat ty nr nc) = (⋅m)"
  unfolding module_mat_def by auto

(* The lemmas in this group each give the in-bounds entry and the two
   dimensions of one matrix operation. *)

(* Zero matrix. *)
lemma index_zero_mat[simp]: "i < nr  j < nc  0m nr nc $$ (i,j) = 0"
  "dim_row (0m nr nc) = nr" "dim_col (0m nr nc) = nc"
  unfolding zero_mat_def by auto

(* Identity matrix. *)
lemma index_one_mat[simp]: "i < n  j < n  1m n $$ (i,j) = (if i = j then 1 else 0)"
  "dim_row (1m n) = n" "dim_col (1m n) = n"
  unfolding one_mat_def by auto

(* Sum; bounds and dimensions are relative to the second summand B. *)
lemma index_add_mat[simp]:
  "i < dim_row B  j < dim_col B  (A + B) $$ (i,j) = A $$ (i,j) + B $$ (i,j)"
  "dim_row (A + B) = dim_row B" "dim_col (A + B) = dim_col B"
  unfolding plus_mat_def by auto

(* Difference; bounds and dimensions are relative to B. *)
lemma index_minus_mat[simp]:
  "i < dim_row B  j < dim_col B  (A - B) $$ (i,j) = A $$ (i,j) - B $$ (i,j)"
  "dim_row (A - B) = dim_row B" "dim_col (A - B) = dim_col B"
  unfolding minus_mat_def by auto

(* Entrywise map. *)
lemma index_map_mat[simp]:
  "i < dim_row A  j < dim_col A  map_mat f A $$ (i,j) = f (A $$ (i,j))"
  "dim_row (map_mat f A) = dim_row A" "dim_col (map_mat f A) = dim_col A"
  unfolding map_mat_def by auto

(* Scalar multiple. *)
lemma index_smult_mat[simp]:
  "i < dim_row A  j < dim_col A  (a m A) $$ (i,j) = a * A $$ (i,j)"
  "dim_row (a m A) = dim_row A" "dim_col (a m A) = dim_col A"
  unfolding smult_mat_def by auto

(* Negation. *)
lemma index_uminus_mat[simp]:
  "i < dim_row A  j < dim_col A  (- A) $$ (i,j) = - (A $$ (i,j))"
  "dim_row (- A) = dim_row A" "dim_col (- A) = dim_col A"
  unfolding uminus_mat_def by auto

(* Transpose; note the swapped bounds on i and j. *)
lemma index_transpose_mat[simp]:
  "i < dim_col A  j < dim_row A  transpose_mat A $$ (i,j) = A $$ (j,i)"
  "dim_row (transpose_mat A) = dim_col A" "dim_col (transpose_mat A) = dim_row A"
  unfolding transpose_mat_def by auto

(* In-bounds entry and dimensions of a matrix product. *)
lemma index_mult_mat[simp]:
  "i < dim_row A  j < dim_col B  (A * B) $$ (i,j) = row A i  col B j"
  "dim_row (A * B) = dim_row A" "dim_col (A * B) = dim_col B"
  by (auto simp: times_mat_def)

(* Dimension of a matrix-vector product. *)
lemma dim_mult_mat_vec[simp]: "dim_vec (A *v v) = dim_row A"
  by (auto simp: mult_mat_vec_def)

(* In-bounds entry of a matrix-vector product. *)
lemma index_mult_mat_vec[simp]: "i < dim_row A  (A *v v) $ i = row A i  v"
  by (auto simp: mult_mat_vec_def)

(* In-bounds entry and dimension of a row. *)
lemma index_row[simp]:
  "i < dim_row A  j < dim_col A  row A i $ j = A $$ (i,j)"
  "dim_vec (row A i) = dim_col A"
  by (auto simp: row_def)

(* In-bounds entry of a column. *)
lemma index_col[simp]: "i < dim_row A  j < dim_col A  col A j $ i = A $$ (i,j)"
  by (auto simp: col_def)

(* The identity matrix is upper triangular. *)
lemma upper_triangular_one[simp]: "upper_triangular (1m n)"
  by (rule, auto)

(* The square zero matrix is upper triangular. *)
lemma upper_triangular_zero[simp]: "upper_triangular (0m n n)"
  by (rule, auto)

(* A matrix built with mat_of_row_fun lies in the carrier of its declared
   dimensions. *)
lemma mat_row_carrierI[intro,simp]: "matr nr nc r  carrier_mat nr nc"
  by (unfold carrier_mat_def carrier_vec_def, auto)

(* Row-wise extensionality: matrices with equal dimensions and equal rows are
   equal. Reduced to eq_matI by reading entry (i,j) as entry j of row i. *)
lemma eq_rowI: assumes rows: " i. i < dim_row B  row A i = row B i"
  and dims: "dim_row A = dim_row B" "dim_col A = dim_col B"
  shows "A = B"
proof (rule eq_matI[OF _ dims])
  fix i j
  assume i: "i < dim_row B" and j: "j < dim_col B"
  from rows[OF i] have id: "row A i $ j = row B i $ j" by simp
  show "A $$ (i, j) = B $$ (i, j)"
    using index_row(1)[OF i j, folded id] index_row(1)[of i A j] i j dims
    by auto
qed

(* Row i of mat nr nc f, expressed directly through f. *)
lemma row_mat[simp]: "i < nr  row (mat nr nc f) i = vec nc (λ j. f (i,j))"
  by auto

(* Column j of mat nr nc f, expressed directly through f. *)
lemma col_mat[simp]: "j < nc  col (mat nr nc f) j = vec nr (λ i. f (i,j))"
  by auto

(* Carrier (dimension) facts for the basic matrix operations. *)

(* The zero matrix lies in its carrier. *)
lemma zero_carrier_mat[simp]: "0m nr nc  carrier_mat nr nc"
  unfolding carrier_mat_def by auto

(* Scalar multiplication preserves the carrier. *)
lemma smult_carrier_mat[simp]:
  "A  carrier_mat nr nc  k m A  carrier_mat nr nc"
  unfolding carrier_mat_def by auto

(* A sum lies in the carrier of its second summand; nothing is assumed
   about A. *)
lemma add_carrier_mat[simp]:
  "B  carrier_mat nr nc  A + B  carrier_mat nr nc"
  unfolding carrier_mat_def by force

(* The identity matrix of size n lies in carrier_mat n n. *)
lemma one_carrier_mat[simp]: "1m n  carrier_mat n n"
  unfolding carrier_mat_def by auto

(* Negation preserves the carrier (one direction, not a simp rule). *)
lemma uminus_carrier_mat:
  "A  carrier_mat nr nc  (- A  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* Negation preserves the carrier, as an equation usable by simp. *)
lemma uminus_carrier_iff_mat[simp]:
  "(- A  carrier_mat nr nc) = (A  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* A difference lies in the carrier of its second argument. *)
lemma minus_carrier_mat:
  "B  carrier_mat nr nc  (A - B  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* Transposition swaps the carrier dimensions. *)
lemma transpose_carrier_mat[simp]: "(transpose_mat A  carrier_mat nc nr) = (A  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* Rows of an nr x nc matrix have dimension nc. *)
lemma row_carrier_vec[simp]: "i < nr  A  carrier_mat nr nc  row A i  carrier_vec nc"
  unfolding carrier_vec_def by auto

(* Columns of an nr x nc matrix have dimension nr. *)
lemma col_carrier_vec[simp]: "j < nc  A  carrier_mat nr nc  col A j  carrier_vec nr"
  unfolding carrier_vec_def by auto

(* The product of an nr x n and an n x nc matrix is nr x nc. *)
lemma mult_carrier_mat[simp]:
  "A  carrier_mat nr n  B  carrier_mat n nc  A * B  carrier_mat nr nc"
  unfolding carrier_mat_def by auto

(* Multiplying an nr x n matrix with a vector of dimension n gives a vector of
   dimension nr. *)
lemma mult_mat_vec_carrier[simp]:
  "A  carrier_mat nr n  v  carrier_vec n  A *v v  carrier_vec nr"
  unfolding carrier_mat_def carrier_vec_def by auto


(* Commutativity of matrix addition for matrices of equal dimensions. *)
lemma comm_add_mat[ac_simps]:
  "(A :: 'a :: comm_monoid_add mat)  carrier_mat nr nc  B  carrier_mat nr nc  A + B = B + A"
  by (intro eq_matI, auto simp: ac_simps)


(* A - A is the zero matrix. *)
lemma minus_r_inv_mat[simp]:
  "(A :: 'a :: group_add mat)  carrier_mat nr nc  (A - A) = 0m nr nc"
  by (intro eq_matI, auto)

(* - A is a left additive inverse of A. *)
lemma uminus_l_inv_mat[simp]:
  "(A :: 'a :: group_add mat)  carrier_mat nr nc  (- A + A) = 0m nr nc"
  by (intro eq_matI, auto)

(* Existence of a two-sided additive inverse inside the carrier; the witness
   used in the proof is - A. *)
lemma add_inv_exists_mat:
  "(A :: 'a :: group_add mat)  carrier_mat nr nc   B  carrier_mat nr nc. B + A = 0m nr nc  A + B = 0m nr nc"
  by (intro bexI[of _ "- A"], auto)

(* Associativity of matrix addition for matrices of equal dimensions. *)
lemma assoc_add_mat[simp]:
  "(A :: 'a :: monoid_add mat)  carrier_mat nr nc  B  carrier_mat nr nc  C  carrier_mat nr nc
   (A + B) + C = A + (B + C)"
  by (intro eq_matI, auto simp: ac_simps)

(* Negation of a sum; the summands appear in reversed order because group_add
   is not assumed to be commutative. *)
lemma uminus_add_mat: fixes A :: "'a :: group_add mat"
  assumes "A  carrier_mat nr nc"
  and "B  carrier_mat nr nc"
  shows "- (A + B) = - B + - A"
  by (intro eq_matI, insert assms, auto simp: minus_add)

(* Transposition is an involution. *)
lemma transpose_transpose[simp]:
  "transpose_mat (transpose_mat A) = A"
  by (intro eq_matI, auto)

(* The identity matrix is its own transpose. *)
lemma transpose_one[simp]: "transpose_mat (1m n) = (1m n)"
  by auto

(* Rows of the transpose are the columns of the original. *)
lemma row_transpose[simp]:
  "j < dim_col A  row (transpose_mat A) j = col A j"
  unfolding row_def col_def
  by (intro eq_vecI, auto)

(* Columns of the transpose are the rows of the original. *)
lemma col_transpose[simp]:
  "i < dim_row A  col (transpose_mat A) i = row A i"
  unfolding row_def col_def
  by (intro eq_vecI, auto)

(* Rows of the zero matrix are zero vectors. *)
lemma row_zero[simp]:
  "i < nr  row (0m nr nc) i = 0v nc"
   by (intro eq_vecI, auto)

(* Columns of the zero matrix are zero vectors. *)
lemma col_zero[simp]:
  "j < nc  col (0m nr nc) j = 0v nr"
   by (intro eq_vecI, auto)

(* Rows of the identity matrix are unit vectors. *)
lemma row_one[simp]:
  "i < n  row (1m n) i = unit_vec n i"
  by (intro eq_vecI, auto)

(* Columns of the identity matrix are unit vectors. *)
lemma col_one[simp]:
  "j < n  col (1m n) j = unit_vec n j"
  by (intro eq_vecI, auto)

(* Transposition commutes with addition. *)
lemma transpose_add: "A  carrier_mat nr nc  B  carrier_mat nr nc
   transpose_mat (A + B) = transpose_mat A + transpose_mat B"
  by (intro eq_matI, auto)

(* Transposition commutes with subtraction. *)
lemma transpose_minus: "A  carrier_mat nr nc  B  carrier_mat nr nc
   transpose_mat (A - B) = transpose_mat A - transpose_mat B"
  by (intro eq_matI, auto)

(* Transposition commutes with negation. *)
lemma transpose_uminus: "A  carrier_mat nr nc  transpose_mat (- A) = - (transpose_mat A)"
  by (intro eq_matI, auto)

(* Row of a sum is the sum of the rows; stated once with carrier assumptions
   and once with explicit dimension equalities. *)
lemma row_add[simp]:
  "A  carrier_mat nr nc  B  carrier_mat nr nc  i < nr
   row (A + B) i = row A i + row B i"
  "i < dim_row A  dim_row B = dim_row A  dim_col B = dim_col A  row (A + B) i = row A i + row B i"
  by (rule eq_vecI, auto)

(* Column of a sum is the sum of the columns. *)
lemma col_add[simp]:
  "A  carrier_mat nr nc  B  carrier_mat nr nc  j < nc
   col (A + B) j = col A j + col B j"
  by (rule eq_vecI, auto)

(* Row i of a product: entry j is the scalar product of row i of A with
   column j of B. *)
lemma row_mult[simp]: assumes m: "A  carrier_mat nr n" "B  carrier_mat n nc"
  and i: "i < nr"
  shows "row (A * B) i = vec nc (λ j. row A i  col B j)"
  by (rule eq_vecI, insert m i, auto)

(* Column j of a product: entry i is the scalar product of row i of A with
   column j of B. *)
lemma col_mult[simp]: assumes m: "A  carrier_mat nr n" "B  carrier_mat n nc"
  and j: "j < nc"
  shows "col (A * B) j = vec nr (λ i. row A i  col B j)"
  by (rule eq_vecI, insert m j, auto)

(* Transpose of a product is the product of the transposes in reversed order.
   Requires a commutative semiring because the proof uses comm_scalar_prod. *)
lemma transpose_mult:
  "(A :: 'a :: comm_semiring_0 mat)  carrier_mat nr n  B  carrier_mat n nc
   transpose_mat (A * B) = transpose_mat B * transpose_mat A"
  by (intro eq_matI, auto simp: comm_scalar_prod[of _ n])

(* The zero matrix is a left neutral element for addition. *)
lemma left_add_zero_mat[simp]:
  "(A :: 'a :: monoid_add mat)  carrier_mat nr nc   0m nr nc + A = A"
  by (intro eq_matI, auto)

(* Adding the negation equals subtracting. *)
lemma add_uminus_minus_mat: "A  carrier_mat nr nc  B  carrier_mat nr nc  
  A + (- B) = A - (B :: 'a :: group_add mat)" 
  by (intro eq_matI, auto)

(* The zero matrix is a right neutral element for addition. *)
lemma right_add_zero_mat[simp]: "A  carrier_mat nr nc  
  A + 0m nr nc = (A :: 'a :: monoid_add mat)" 
  by (intro eq_matI, auto)

(* Multiplication by zero and identity matrices. Each law comes in two forms:
   one with a carrier assumption, and a primed [simp] variant that only
   assumes the one dimension equation actually needed. *)

(* Zero matrix on the left. *)
lemma left_mult_zero_mat:
  "A  carrier_mat n nc  0m nr n * A = 0m nr nc"
  by (intro eq_matI, auto)

lemma left_mult_zero_mat'[simp]: "dim_row A = n  0m nr n * A = 0m nr (dim_col A)"
  by (rule left_mult_zero_mat, unfold carrier_mat_def, simp)

(* Zero matrix on the right. *)
lemma right_mult_zero_mat:
  "A  carrier_mat nr n  A * 0m n nc = 0m nr nc"
  by (intro eq_matI, auto)

lemma right_mult_zero_mat'[simp]: "dim_col A = n  A * 0m n nc = 0m (dim_row A) nc"
  by (rule right_mult_zero_mat, unfold carrier_mat_def, simp)

(* Identity matrix on the left. *)
lemma left_mult_one_mat:
  "(A :: 'a :: semiring_1 mat)  carrier_mat nr nc  1m nr * A = A"
  by (intro eq_matI, auto)

lemma left_mult_one_mat'[simp]: "dim_row (A :: 'a :: semiring_1 mat) = n  1m n * A = A"
  by (rule left_mult_one_mat, unfold carrier_mat_def, simp)

(* Identity matrix on the right. *)
lemma right_mult_one_mat:
  "(A :: 'a :: semiring_1 mat)  carrier_mat nr nc  A * 1m nc = A"
  by (intro eq_matI, auto)

lemma right_mult_one_mat'[simp]: "dim_col (A :: 'a :: semiring_1 mat) = n  A * 1m n = A"
  by (rule right_mult_one_mat, unfold carrier_mat_def, simp)

(* Multiplying the identity matrix with a vector leaves the vector unchanged. *)
lemma one_mult_mat_vec[simp]:
  "(v :: 'a :: semiring_1 vec)  carrier_vec n  1m n *v v = v"
  by (intro eq_vecI, auto)

(* Subtraction expressed as addition of the negated matrix. *)
lemma minus_add_uminus_mat: fixes A :: "'a :: group_add mat"
  shows "A  carrier_mat nr nc  B  carrier_mat nr nc 
  A - B = A + (- B)"
  by (intro eq_matI, auto)

lemma add_mult_distrib_mat[algebra_simps]: assumes m: "A  carrier_mat nr n"
  "B  carrier_mat nr n" "C  carrier_mat n nc"
  shows "(A + B) * C = A * C + B * C"
  using m by (intro eq_matI, auto simp: add_scalar_prod_distrib[of _ n])

lemma mult_add_distrib_mat[algebra_simps]: assumes m: "A  carrier_mat nr n"
  "B  carrier_mat n nc" "C  carrier_mat n nc"
  shows "A * (B + C) = A * B + A * C"
  using m by (intro eq_matI, auto simp: scalar_prod_add_distrib[of _ n])

(* Matrix-vector multiplication distributes over matrix addition:
   applying A + B to v equals the sum of applying A and applying B,
   for A, B in carrier_mat nr nc and v in carrier_vec nc.
   Indexwise proof using add_scalar_prod_distrib as an intro rule. *)
lemma add_mult_distrib_mat_vec[algebra_simps]: assumes m: "A  carrier_mat nr nc"
  "B  carrier_mat nr nc" "v  carrier_vec nc"
  shows "(A + B) *v v = A *v v + B *v v"
  using m by (intro eq_vecI, auto intro!: add_scalar_prod_distrib)

(* Matrix-vector multiplication distributes over vector addition:
   applying A to v1 + v2 equals the sum of applying A to v1 and to v2,
   for A in carrier_mat nr nc and v1, v2 in carrier_vec nc.
   Indexwise proof; scalar_prod_add_distrib is instantiated with nc. *)
lemma mult_add_distrib_mat_vec[algebra_simps]: assumes m: "A  carrier_mat nr nc"
  "v1  carrier_vec nc" "v2  carrier_vec nc"
  shows "A *v (v1 + v2) = A *v v1 + A *v v2"
  using m by (intro eq_vecI, auto simp: scalar_prod_add_distrib[of _ nc])

(* Scaling a vector by k commutes with applying a matrix to it (stated for a
   field element type): A applied to the k-scaled v equals the k-scaled
   result of A applied to v.
   Proof outline: the initial proof step yields two goals, equality of the
   dimensions and equality of the entries at every index i. The entrywise goal
   is handled by unfolding mult_mat_vec_def, smult_vec_def and scalar_prod_def
   and then pulling the scalar out of the sum (sum_distrib_left together with
   mult.left_commute). *)
lemma mult_mat_vec:
  assumes m: "(A::'a::field mat)  carrier_mat nr nc" and v: "v  carrier_vec nc"
  shows "A *v (k v v) = k v (A *v v)" (is "?l = ?r")
proof
  (* Both sides have dimension nr. *)
  have nr: "dim_vec ?l = nr" using m v by auto
  also have "... = dim_vec ?r" using m v by auto
  finally show "dim_vec ?l = dim_vec ?r".

  show "i. i < dim_vec ?r  ?l $ i = ?r $ i"
  proof -
    fix i assume "i < dim_vec ?r"
    (* Index bounds needed to evaluate index_vec on both sides. *)
    hence i: "i < dim_row A" using nr m by auto
    hence i2: "i < dim_vec (A *v v)" using m by auto
    show "?l $ i = ?r $ i"
    apply (subst (1) mult_mat_vec_def)
    apply (subst (2) smult_vec_def)
    unfolding index_vec[OF i] index_vec[OF i2]
    unfolding mult_mat_vec_def smult_vec_def
    unfolding scalar_prod_def index_vec[OF i]
    by (simp add: mult.left_commute sum_distrib_left)
  qed
qed

(* Core identity behind associativity of matrix multiplication:
   the scalar product of (the vector of products of v1 with the columns of A)
   and v2 equals the scalar product of v1 and (the vector of products of the
   rows of A with v2), for v1 in carrier_vec nr, A in carrier_mat nr nc and
   v2 in carrier_vec nc.
   Proof outline: expand both scalar products into nested sums over
   {0..<nc} and {0..<nr}, distribute the outer factor into the inner sum,
   exchange the order of summation (sum.swap), reassociate the products and
   fold the result back into a scalar product. *)
lemma assoc_scalar_prod: assumes *: "v1  carrier_vec nr" "A  carrier_mat nr nc" "v2  carrier_vec nc"
  shows "vec nc (λj. v1  col A j)  v2 = v1  vec nr (λi. row A i  v2)"
proof -
  have "vec nc (λj. v1  col A j)  v2 = (i{0..<nc}. vec nc (λj. k{0..<nr}. v1 $ k * col A j $ k) $ i * v2 $ i)"
    unfolding scalar_prod_def using * by auto
  also have " = (i{0..<nc}. (k{0..<nr}. v1 $ k * col A i $ k) * v2 $ i)"
    by (rule sum.cong, auto)
  also have " = (i{0..<nc}. (k{0..<nr}. v1 $ k * col A i $ k * v2 $ i))"
    unfolding sum_distrib_right ..
  (* Exchange the two summations. *)
  also have " = (k{0..<nr}. (i{0..<nc}. v1 $ k * col A i $ k * v2 $ i))"
    by (rule sum.swap)
  also have " = (k{0..<nr}. (i{0..<nc}. v1 $ k * (col A i $ k * v2 $ i)))"
    by (simp add: ac_simps)
  also have " = (k{0..<nr}. v1 $ k * (i{0..<nc}. col A i $ k * v2 $ i))"
    unfolding sum_distrib_left ..
  (* Switch from column access to row access of A. *)
  also have " = (k{0..<nr}. v1 $ k * vec nr (λk. i{0..<nc}. row A k $ i * v2 $ i) $ k)"
    using * by auto
  also have " = v1  vec nr (λi. row A i  v2)" unfolding scalar_prod_def using * by simp
  finally show ?thesis .
qed

(* Associativity of matrix multiplication for dimension-compatible matrices
   A (n1 x n2), B (n2 x n3) and C (n3 x n4). Entrywise proof that relies on
   assoc_scalar_prod. Registered as a simp rule. *)
lemma assoc_mult_mat[simp]:
  "A  carrier_mat n1 n2  B  carrier_mat n2 n3  C  carrier_mat n3 n4
   (A * B) * C = A * (B * C)"
  by (intro eq_matI, auto simp: assoc_scalar_prod)

(* Applying the product A * B to a vector v equals applying B first and then A,
   for A (n1 x n2), B (n2 x n3) and v in carrier_vec n3. Indexwise proof using
   mult_mat_vec_def and assoc_scalar_prod. Registered as a simp rule. *)
lemma assoc_mult_mat_vec[simp]:
  "A  carrier_mat n1 n2  B  carrier_mat n2 n3  v  carrier_vec n3
   (A * B) *v v = A *v (B *v v)"
  by (intro eq_vecI, auto simp add: mult_mat_vec_def assoc_scalar_prod)

(* The structure monoid_mat for nr-by-nc matrices over a comm_monoid_add type
   satisfies the comm_monoid locale. Locale obligations are discharged by
   unfolding monoid_mat_def and using ac_simps. *)
lemma comm_monoid_mat: "comm_monoid (monoid_mat TYPE('a :: comm_monoid_add) nr nc)"
  by (unfold_locales, auto simp: monoid_mat_def ac_simps)

(* The structure monoid_mat for nr-by-nc matrices over an ab_group_add type
   satisfies the comm_group locale. Existence of inverses is supplied by
   add_inv_exists_mat (inserted as a fact); Units_def is unfolded by simp. *)
lemma comm_group_mat: "comm_group (monoid_mat TYPE('a :: ab_group_add) nr nc)"
  by (unfold_locales, insert add_inv_exists_mat, auto simp: monoid_mat_def ac_simps Units_def)

(* The structure ring_mat TYPE('a) n b over a semiring_1 element type satisfies
   the semiring locale. Locale obligations are discharged by unfolding
   ring_mat_def and using algebra_simps. *)
lemma semiring_mat: "semiring (ring_mat TYPE('a :: semiring_1) n b)"
  by (unfold_locales, auto simp: ring_mat_def algebra_simps)

(* The structure ring_mat TYPE('a) n b over a comm_ring_1 element type satisfies
   the ring locale. Additive inverses come from add_inv_exists_mat; the rest is
   handled by ring_mat_def, algebra_simps and Units_def. *)
lemma ring_mat: "ring (ring_mat TYPE('a :: comm_ring_1) n b)"
  by (unfold_locales, insert add_inv_exists_mat, auto simp: ring_mat_def algebra_simps Units_def)

(* The structure module_mat TYPE('a) nr nc over a comm_ring_1 element type
   satisfies the abelian_group locale. Additive inverses come from
   add_inv_exists_mat; module_mat_def and Units_def are unfolded by simp. *)
lemma abelian_group_mat: "abelian_group (module_mat TYPE('a :: comm_ring_1) nr nc)"
  by (unfold_locales, insert add_inv_exists_mat, auto simp: module_mat_def Units_def)

(* Row i of the k-scaled matrix A is the k-scaled row i of A, for a valid row
   index i < dim_row A. Simp rule. *)
lemma row_smult[simp]: assumes i: "i < dim_row A"
  shows "row (k m A) i = k v (row A i)"
  by (rule eq_vecI, insert i, auto)

(* Column i of the k-scaled matrix A is the k-scaled column i of A, for a valid
   column index i < dim_col A. Simp rule. *)
lemma col_smult[simp]: assumes i: "i < dim_col A"
  shows "col (k m A) i = k v (col A i)"
  by (rule eq_vecI, insert i, auto)

(* Row i of the negated matrix is the negation of row i, for i < dim_row A.
   Simp rule. *)
lemma row_uminus[simp]: assumes i: "i < dim_row A"
  shows "row (- A) i = - (row A i)"
  by (rule eq_vecI, insert i, auto)

(* A negation in the left argument of a scalar product can be pulled out, for
   vectors of equal dimension over a ring. The dimension equation is used in
   reversed orientation (dim[symmetric]) when unfolding scalar_prod_def;
   sum_negf then moves the negation across the sum. *)
lemma scalar_prod_uminus_left[simp]: assumes dim: "dim_vec v = dim_vec (w :: 'a :: ring vec)"
  shows "- v  w = - (v  w)"
  unfolding scalar_prod_def dim[symmetric]
  by (subst sum_negf[symmetric], rule sum.cong, auto)

(* Column i of the negated matrix is the negation of column i, for
   i < dim_col A. Simp rule. *)
lemma col_uminus[simp]: assumes i: "i < dim_col A"
  shows "col (- A) i = - (col A i)"
  by (rule eq_vecI, insert i, auto)

(* A negation in the right argument of a scalar product can be pulled out, for
   vectors of equal dimension over a ring. Here the dimension equation is used
   in its stated orientation when unfolding scalar_prod_def; sum_negf then
   moves the negation across the sum. *)
lemma scalar_prod_uminus_right[simp]: assumes dim: "dim_vec v = dim_vec (w :: 'a :: ring vec)"
  shows "v  - w = - (v  w)"
  unfolding scalar_prod_def dim
  by (subst sum_negf[symmetric], rule sum.cong, auto)

(* Anonymous context fixing two matrices A and B over a ring whose inner
   dimensions agree (dim_col A = dim_row B). After the context is closed this
   assumption becomes a premise of both exported lemmas. *)
context fixes A B :: "'a :: ring mat"
  assumes dim: "dim_col A = dim_row B"
begin
(* A negation on the left factor of a product can be pulled out. *)
lemma uminus_mult_left_mat[simp]: "(- A * B) = - (A * B)"
  by (intro eq_matI, insert dim, auto)

(* A negation on the right factor of a product can be pulled out. *)
lemma uminus_mult_right_mat[simp]: "(A * - B) = - (A * B)"
  by (intro eq_matI, insert dim, auto)
end

(* Right distributivity of matrix multiplication over subtraction:
   (A - B) * C = A * C - B * C for A, B in carrier_mat nr n and
   C in carrier_mat n nc over a ring.
   Proof: rewrite the difference as a sum with - B (minus_add_uminus_mat),
   apply add_mult_distrib_mat, then pull the negation out of the product
   with uminus_mult_left_mat. *)
lemma minus_mult_distrib_mat[algebra_simps]: fixes A :: "'a :: ring mat"
  assumes m: "A  carrier_mat nr n" "B  carrier_mat nr n" "C  carrier_mat n nc"
  shows "(A - B) * C = A * C - B * C"
  unfolding minus_add_uminus_mat[OF m(1,2)]
    add_mult_distrib_mat[OF m(1) uminus_carrier_mat[OF m(2)] m(3)]
  by (subst uminus_mult_left_mat, insert m, auto)

(* Matrix-vector multiplication distributes over matrix subtraction, for
   A, B in carrier_mat nr nc and v in carrier_vec nc over a ring.
   Proof: rewrite A - B as A + (- B) and apply add_mult_distrib_mat_vec. *)
lemma minus_mult_distrib_mat_vec[algebra_simps]: assumes A: "(A :: 'a :: ring mat)  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc"
  and v: "v  carrier_vec nc"
shows "(A - B) *v v = A *v v - B *v v"
  unfolding minus_add_uminus_mat[OF A B]
  by (subst add_mult_distrib_mat_vec[OF A _ v], insert A B v, auto)

(* Matrix-vector multiplication distributes over vector subtraction, for
   A in carrier_mat nr nc and v, w in carrier_vec nc over a ring.
   Proof: rewrite v - w as a sum with - w (minus_add_uminus_vec) and apply
   mult_add_distrib_mat_vec. *)
lemma mult_minus_distrib_mat_vec[algebra_simps]: assumes A: "(A :: 'a :: ring mat)  carrier_mat nr nc"
  and v: "v  carrier_vec nc"
  and w: "w  carrier_vec nc"
shows "A *v (v - w) = A *v v - A *v w"
  unfolding minus_add_uminus_vec[OF v w]
  by (subst mult_add_distrib_mat_vec[OF A], insert A v w, auto)

(* Left distributivity of matrix multiplication over subtraction:
   A * (B - C) = A * B - A * C for A in carrier_mat nr n and
   B, C in carrier_mat n nc over a ring.
   Proof: rewrite the difference as a sum with - C, apply
   mult_add_distrib_mat, then pull the negation out of the product with
   uminus_mult_right_mat. *)
lemma mult_minus_distrib_mat[algebra_simps]: fixes A :: "'a :: ring mat"
  assumes m: "A  carrier_mat nr n" "B  carrier_mat n nc" "C  carrier_mat n nc"
  shows "A * (B - C) = A * B - A * C"
  unfolding minus_add_uminus_mat[OF m(2,3)]
    mult_add_distrib_mat[OF m(1) m(2) uminus_carrier_mat[OF m(3)]]
  by (subst uminus_mult_right_mat, insert m, auto)

(* Applying the negated matrix to v equals the negation of applying the matrix
   to v, provided dim_vec v = dim_col A (ring element type). Simp rule. *)
lemma uminus_mult_mat_vec[simp]: assumes v: "dim_vec v = dim_col (A :: 'a :: ring mat)"
  shows "- A *v v = - (A *v v)"
  using v by (intro eq_vecI, auto)

(* For v in carrier_vec n over a group_add type: the negation of v is the zero
   vector of dimension n exactly when v itself is.
   The structured part proves the direction from - v = 0v n to v = 0v n by
   showing every entry v $ i is 0 (double negation, then the hypothesis read
   at index i); the converse direction is closed by the final auto. *)
lemma uminus_zero_vec_eq: assumes v: "(v :: 'a :: group_add vec)  carrier_vec n"
  shows "(- v = 0v n) = (v = 0v n)"
proof
  assume z: "- v = 0v n"
  {
    fix i
    assume i: "i < n"
    have "v $ i = - (- (v $ i))" by simp
    (* Evaluate the hypothesis z at position i. *)
    also have "- (v $ i) = 0" using arg_cong[OF z, of "λ v. v $ i"] i v by auto
    also have "- 0 = (0 :: 'a)" by simp
    finally have "v $ i = 0" .
  }
  thus "v = 0v n" using v
    by (intro eq_vecI, auto)
qed auto

(* Mapping a function over the entries of a matrix does not change which
   carrier_mat nr nc it belongs to. Simp rule, stated as an equation between
   the two membership statements. *)
lemma map_carrier_mat[simp]:
  "(map_mat f A  carrier_mat nr nc) = (A  carrier_mat nr nc)"
  unfolding carrier_mat_def by auto

(* Column j of a mapped matrix is the mapped column j, for j < dim_col A.
   Simp rule; proved by unfolding map_mat_def and map_vec_def. *)
lemma col_map_mat[simp]:
  assumes "j < dim_col A" shows "col (map_mat f A) j = map_vec f (col A j)"
  unfolding map_mat_def map_vec_def using assms by auto

(* Scaling a vector by 1 leaves it unchanged (semiring_1 element type).
   Simp rule. *)
lemma scalar_vec_one[simp]: "1 v (v :: 'a :: semiring_1 vec) = v"
  by (rule eq_vecI, auto)

(* A scalar factor k on the right argument of a scalar product can be pulled
   out in front, for vectors of equal dimension over a comm_semiring_0.
   Proof: unfold scalar_prod_def, distribute k into the sum and compare the
   summands up to associativity/commutativity. Simp rule. *)
lemma scalar_prod_smult_right[simp]:
  "dim_vec w = dim_vec v  w  (k v v) = (k :: 'a :: comm_semiring_0) * (w  v)"
  unfolding scalar_prod_def sum_distrib_left
  by (auto intro: sum.cong simp: ac_simps)

(* A scalar factor k on the left argument of a scalar product can be pulled
   out in front, for vectors of equal dimension over a comm_semiring_0.
   Same proof scheme as the right-hand variant. Simp rule. *)
lemma scalar_prod_smult_left[simp]:
  "dim_vec w = dim_vec v  (k v w)  v = (k :: 'a :: comm_semiring_0) * (w  v)"
  unfolding scalar_prod_def sum_distrib_left
  by (auto intro: sum.cong simp: ac_simps)

(* A scalar factor on the right matrix of a product can be moved in front of
   the whole product, for A in carrier_mat nr n and B in carrier_mat n nc over
   a comm_semiring_0. *)
lemma mult_smult_distrib: assumes A: "A  carrier_mat nr n" and B: "B  carrier_mat n nc"
  shows "A * (k m B) = (k :: 'a :: comm_semiring_0) m (A * B)"
  by (rule eq_matI, insert A B, auto)

(* Scalar multiplication distributes over matrix addition, for A and B in the
   same carrier_mat nr nc over a semiring. Entrywise proof with field_simps. *)
lemma add_smult_distrib_left_mat: assumes "A  carrier_mat nr nc" "B  carrier_mat nr nc"
  shows "k m (A + B) = (k :: 'a :: semiring) m A + k m B"
  by (rule eq_matI, insert assms, auto simp: field_simps)

(* Scaling by a sum of scalars k + l equals the sum of the two scaled matrices,
   for A in carrier_mat nr nc over a semiring. Entrywise proof with
   field_simps. *)
lemma add_smult_distrib_right_mat: assumes "A  carrier_mat nr nc"
  shows "(k + l) m A = (k :: 'a :: semiring) m A + l m A"
  by (rule eq_matI, insert assms, auto simp: field_simps)

(* A scalar factor on the left matrix of a product can be moved in front of
   the whole product, for A in carrier_mat nr n and B in carrier_mat n nc over
   a comm_semiring_0. *)
lemma mult_smult_assoc_mat: assumes A: "A  carrier_mat nr n" and B: "B  carrier_mat n nc"
  shows "(k m A) * B = (k :: 'a :: comm_semiring_0) m (A * B)"
  by (rule eq_matI, insert A B, auto)

(* Similarity of A and B with explicit witnesses P and Q.
   With n = dim_row A the definition lists these conditions: A, B, P and Q
   all lie in carrier_mat n n; P * Q and Q * P both equal 1m n (so P and Q are
   mutually inverse); and A = P * B * Q. *)
definition similar_mat_wit :: "'a :: semiring_1 mat  'a mat  'a mat  'a mat  bool" where
  "similar_mat_wit A B P Q = (let n = dim_row A in {A,B,P,Q}  carrier_mat n n  P * Q = 1m n  Q * P = 1m n 
    A = P * B * Q)"

(* Similarity of A and B without naming the witnesses: some P and Q satisfy
   similar_mat_wit A B P Q. *)
definition similar_mat :: "'a :: semiring_1 mat  'a mat  bool" where
  "similar_mat A B = ( P Q. similar_mat_wit A B P Q)"

(* Destruction rule for similar_mat: from similar_mat A B one obtains a
   dimension n and matrices P, Q meeting the conditions spelled out in
   similar_mat_wit, with the let-binding already expanded. *)
lemma similar_matD: assumes "similar_mat A B"
  shows " n P Q. {A,B,P,Q}  carrier_mat n n  P * Q = 1m n  Q * P = 1m n  A = P * B * Q"
  using assms unfolding similar_mat_def similar_mat_wit_def[abs_def] Let_def by blast

(* Introduction rule for similar_mat: given concrete P and Q in carrier_mat n n
   that are mutually inverse and satisfy A = P * B * Q, conclude
   similar_mat A B. P and Q are supplied as the existential witnesses. *)
lemma similar_matI: assumes "{A,B,P,Q}  carrier_mat n n" "P * Q = 1m n" "Q * P = 1m n" "A = P * B * Q"
  shows "similar_mat A B" unfolding similar_mat_def
  by (rule exI[of _ P], rule exI[of _ Q], unfold similar_mat_wit_def Let_def, insert assms, auto)

(* Matrix power by recursion on the exponent (infix notation, priority 75):
   exponent 0 yields 1m (dim_row A); exponent Suc k multiplies the k-th power
   by A on the right. No squareness of A is required by the definition
   itself. *)
fun pow_mat :: "'a :: semiring_1 mat  nat  'a mat" (infixr "^m" 75) where
  "A ^m 0 = 1m (dim_row A)"
| "A ^m (Suc k) = A ^m k * A"

(* Dimensions of a matrix power for an arbitrary (possibly non-square) A:
   the row count is always dim_row A; the column count is dim_row A for
   exponent 0 and dim_col A otherwise. Proved by induction on k. *)
lemma pow_mat_dim[simp]:
  "dim_row (A ^m k) = dim_row A"
  "dim_col (A ^m k) = (if k = 0 then dim_row A else dim_col A)"
  by (induct k, auto)

(* Specialisation of the dimension facts to square A in carrier_mat n n:
   both dimensions of every power are n. *)
lemma pow_mat_dim_square[simp]:
  "A  carrier_mat n n  dim_row (A ^m k) = n"
  "A  carrier_mat n n  dim_col (A ^m k) = n"
  by auto

(* Powers of a matrix in carrier_mat n n stay in carrier_mat n n. Simp rule. *)
lemma pow_carrier_mat[simp]: "A  carrier_mat n n  A ^m k  carrier_mat n n"
  unfolding carrier_mat_def by auto

(* The list of diagonal entries A $$ (i,i) for i ranging over
   [0 ..< dim_row A]. The index range is taken from the row count only. *)
definition diag_mat :: "'a mat  'a list" where
  "diag_mat A = map (λ i. A $$ (i,i)) [0 ..< dim_row A]"

(* The list product of the diagonal entries equals the indexed product of
   A $$ (i,i) over i = 0 ..< dim_row A. Proved by converting between the
   set-indexed and the list-based product (prod.distinct_set_conv_list). *)
lemma prod_list_diag_prod: "prod_list (diag_mat A) = ( i = 0 ..< dim_row A. A $$ (i,i))"
  unfolding diag_mat_def
  by (subst prod.distinct_set_conv_list[symmetric], auto)

(* Transposition does not change the diagonal list, provided
   dim_row A = dim_col A. Simp rule. *)
lemma diag_mat_transpose[simp]: "dim_row A = dim_col A 
  diag_mat (transpose_mat A) = diag_mat A" unfolding diag_mat_def by auto

(* The diagonal of the n-by-n zero matrix is the list of n zeros. Simp rule;
   proved position by position with nth_equalityI. *)
lemma diag_mat_zero[simp]: "diag_mat (0m n n) = replicate n 0"
  unfolding diag_mat_def
  by (rule nth_equalityI, auto)

(* The diagonal of 1m n is the list of n ones. Simp rule; proved position by
   position with nth_equalityI. *)
lemma diag_mat_one[simp]: "diag_mat (1m n) = replicate n 1"
  unfolding diag_mat_def
  by (rule nth_equalityI, auto)

(* For square A in carrier_mat n n, the recursively defined matrix power
   coincides with the generic power operation of the algebraic structure
   ring_mat TYPE('a) n b (abbreviated ?C in the proof).
   Proof: interpret the semiring locale for ?C (via semiring_mat), then
   induct on k, unfolding ring_mat_def and nat_pow_def. *)
lemma pow_mat_ring_pow: assumes A: "(A :: ('a :: semiring_1)mat)  carrier_mat n n"
  shows "A ^m k = A [^]ring_mat TYPE('a) n b k"
  (is "_ = A [^]?C k")
proof -
  interpret semiring ?C by (rule semiring_mat)
  show ?thesis
    by (induct k, insert A, auto simp: ring_mat_def nat_pow_def)
qed

(* Predicate for diagonal matrices: within the dimensions of A
   (i < dim_row A, j < dim_col A), every entry at a position (i,j) with i and j
   different is 0. The matrix is not required to be square. *)
definition diagonal_mat :: "'a::zero mat  bool" where
  "diagonal_mat A  i<dim_row A. j<dim_col A. i  j  A $$ (i,j) = 0"

(* Sum of all entries of a matrix, defined inside the comm_monoid_add class
   context: the entries A $$ ij are summed over all index pairs in
   {0 ..< dim_row A} x {0 ..< dim_col A}. *)
definition (in comm_monoid_add) sum_mat :: "'a mat  'a" where
  "sum_mat A = sum (λ ij. A $$ ij) ({0 ..< dim_row A} × {0 ..< dim_col A})"

(* The entry sum of any zero matrix is 0. Simp rule; every summand is 0, so
   sum.neutral applies. *)
lemma sum_mat_0[simp]: "sum_mat (0m nr nc) = (0 :: 'a :: comm_monoid_add)"
  unfolding sum_mat_def
  by (rule sum.neutral, auto)

(* The entry sum is additive: for A and B in the same carrier_mat nr nc,
   sum_mat (A + B) = sum_mat A + sum_mat B.
   Proof: replace all dimensions by nr / nc (fact id), merge the two sums with
   sum.distrib and compare summands. *)
lemma sum_mat_add: assumes A: "(A :: 'a :: comm_monoid_add mat)  carrier_mat nr nc" and B: "B  carrier_mat nr nc"
  shows "sum_mat (A + B) = sum_mat A + sum_mat B"
proof -
  from A B have id: "dim_row A = nr" "dim_row B = nr" "dim_col A = nc" "dim_col B = nc"
    by auto
  show ?thesis unfolding sum_mat_def id
    by (subst sum.distrib[symmetric], rule sum.cong, insert A B, auto)
qed

subsection ‹Update Operators›

(* Point update of a vector (mixfix notation with priorities [60,61,62] 60):
   the result has the same dimension as v, holds a at index i and agrees with v
   at every other index. *)
definition update_vec :: "'a vec  nat  'a  'a vec" ("_ |v _  _" [60,61,62] 60)
  where "v |v i  a = vec (dim_vec v) (λi'. if i' = i then a else v $ i')"

(* Point update of a matrix (mixfix notation with priorities [60,61,62] 60):
   the result has the same dimensions as A, holds a at index pair ij and agrees
   with A at every other index pair. *)
definition update_mat :: "'a mat  nat × nat  'a  'a mat" ("_ |m _  _" [60,61,62] 60)
  where "A |m ij  a = mat (dim_row A) (dim_col A) (λij'. if ij' = ij then a else A $$ ij')"

(* A point update does not change the dimension of a vector. Simp rule. *)
lemma dim_update_vec[simp]:
  "dim_vec (v |v i  a) = dim_vec v" unfolding update_vec_def by simp

(* Reading the updated position returns the new value, provided the index is
   within bounds (i < dim_vec v). Simp rule. *)
lemma index_update_vec1[simp]:
  assumes "i < dim_vec v" shows "(v |v i  a) $ i = a"
  unfolding update_vec_def using assms by simp

(* Reading a position i' different from the updated position i returns the old
   value. Note that no bound on i' is assumed; the proof therefore goes through
   transfer and unfolds mk_vec_def instead of using the bounded index rule. *)
lemma index_update_vec2[simp]:
  assumes "i'  i" shows "(v |v i  a) $ i' = v $ i'"
  unfolding update_vec_def
  using assms apply transfer unfolding mk_vec_def by auto

(* A point update changes neither the row count nor the column count of a
   matrix. Both facts are simp rules. *)
lemma dim_update_mat[simp]:
  "dim_row (A |m ij  a) = dim_row A"
  "dim_col (A |m ij  a) = dim_col A" unfolding update_mat_def by simp+

(* Reading the updated position (i,j) returns the new value, provided both
   indices are within the dimensions of A. Simp rule. *)
lemma index_update_mat1[simp]:
  assumes "i < dim_row A" "j < dim_col A" shows "(A |m (i,j)  a) $$ (i,j) = a"
  unfolding update_mat_def using assms by simp

(* Reading an in-bounds position (i',j') that differs from the updated index
   pair ij returns the old value. Unlike the vector analogue, bounds on both
   indices are assumed here. Simp rule. *)
lemma index_update_mat2[simp]:
  assumes i': "i' < dim_row A" and j': "j' < dim_col A" and neq: "(i',j')  ij"
  shows "(A |m ij  a) $$ (i',j') = A $$ (i',j')"
  unfolding update_mat_def using assms by auto

subsection ‹Block Vectors and Matrices›

(* Concatenation of two vectors (right-associative infix, priority 65).
   With n = dim_vec v and m = dim_vec w the result has dimension n + m;
   indices below n read from v, all other indices read from w at i - n. *)
definition append_vec :: "'a vec  'a vec  'a vec" (infixr "@v" 65) where
  "v @v w  let n = dim_vec v; m = dim_vec w in
    vec (n + m) (λ i. if i < n then v $ i else w $ (i - n))"

(* Two simp facts about appended vectors: (1) for an in-bounds index i the
   entry is taken from v when i < dim_vec v and from w at the shifted index
   otherwise; (2) the dimension is the sum of both dimensions. *)
lemma index_append_vec[simp]: "i < dim_vec v + dim_vec w
   (v @v w) $ i = (if i < dim_vec v then v $ i else w $ (i - dim_vec v))"
  "dim_vec (v @v w) = dim_vec v + dim_vec w"
  unfolding append_vec_def Let_def by auto

(* Appending a vector of carrier_vec n1 and one of carrier_vec n2 gives a
   vector of carrier_vec (n1 + n2). Registered both as simp and intro rule. *)
lemma append_carrier_vec[simp,intro]:
  "v  carrier_vec n1  w  carrier_vec n2  v @v w  carrier_vec (n1 + n2)"
  unfolding carrier_vec_def by auto

(* The scalar product of two appended vectors splits into the sum of the
   scalar products of the corresponding parts, provided the first parts both
   have dimension n1 and the second parts both have dimension n2.
   Proof outline: split the index range {0 ..< n1 + n2} at n1 (fact id),
   separate the sum with sum.union_disjoint, then re-index the upper range
   {n1 ..< n1 + n2} as the image of {0 ..< n2} under adding n1 (fact id2). *)
lemma scalar_prod_append: assumes "v1  carrier_vec n1" "v2  carrier_vec n2"
  "w1  carrier_vec n1" "w2  carrier_vec n2"
  shows "(v1 @v v2)  (w1 @v w2) = v1  w1 + v2  w2"
proof -
  from assms have dim: "dim_vec v1 = n1" "dim_vec v2 = n2" "dim_vec w1 = n1" "dim_vec w2 = n2" by auto
  have id: "{0 ..< n1 + n2} = {0 ..< n1}  {n1 ..< n1 + n2}" by auto
  have id2: "{n1 ..< n1 + n2} = (plus n1) ` {0 ..< n2}"
    by (simp add: ac_simps)
  have "(v1 @v v2)  (w1 @v w2) = (i = 0..<n1. v1 $ i * w1 $ i) +
    (i = n1..<n1 + n2. v2 $ (i - n1) * w2 $ (i - n1))"
  unfolding scalar_prod_def
    by (auto simp: dim id, subst sum.union_disjoint, insert assms, force+)
  (* Shift the upper index range down by n1. *)
  also have "(i = n1..<n1 + n2. v2 $ (i - n1) * w2 $ (i - n1))
    = (i = 0..< n2. v2 $ i * w2 $ i)"
    by (rule sum.reindex_cong [OF _ id2]) simp_all
  finally show ?thesis by (simp, insert assms, auto simp: scalar_prod_def)
qed

(* vec_first v n: the vector of dimension n holding the entries of v at
   indices 0 .. n-1.
   vec_last v n: the vector of dimension n holding the entries of v starting
   at index dim_vec v - n, i.e. the final n entries when n does not exceed
   dim_vec v (the subtraction is on natural numbers). *)
definition "vec_first v n  vec n (λi. v $ i)"
definition "vec_last v n  vec n (λi. v $ (dim_vec v - n + i))"

(* vec_first v n and vec_last v n both have dimension n, independently of the
   dimension of v. Simp rules. *)
lemma dim_vec_first[simp]: "dim_vec (vec_first v n) = n" unfolding vec_first_def by auto
lemma dim_vec_last[simp]: "dim_vec (vec_last v n) = n" unfolding vec_last_def by auto

(* Carrier versions of the dimension facts: vec_first v n and vec_last v n lie
   in carrier_vec n. Simp rules. *)
lemma vec_first_carrier[simp]: "vec_first v n  carrier_vec n" by (rule carrier_vecI, auto)
lemma vec_last_carrier[simp]: "vec_last v n  carrier_vec n" by (rule carrier_vecI, auto)

(* Splitting a vector of carrier_vec (n+m) into its first n and last m entries
   and appending the two parts reconstructs the vector. Simp rule. *)
lemma vec_first_last_append[simp]:
  assumes "v  carrier_vec (n+m)" shows "vec_first v n @v vec_last v m = v"
  apply(rule) unfolding vec_first_def vec_last_def using assms by auto

(* Relates the vector order (unfolded via less_eq_vec_def) on two appended
   vectors to the order on their parts, where the first parts v and w both
   have dimension n: the comparison of v @v v' with w @v w' is characterised
   by the comparison of v with w together with that of v' with w'.
   The braced block derives the indexwise comparison of v' and w' from the
   comparison at the shifted index n + i; the rest is done by auto. *)
lemma append_vec_le: assumes "v  carrier_vec n" and w: "w  carrier_vec n" 
  shows "v @v v'  w @v w'  v  w  v'  w'" 
proof -
  {
    fix i
    assume *: "i. (¬ i < n  i < n + dim_vec w'  v' $ (i - n)  w' $ (i - n))"
      and i: "i < dim_vec w'" 
    (* Instantiate the hypothesis at index n + i, which lies in the second part. *)
    have "v' $ i  w' $ i" using *[rule_format, of "n + i"] i by auto
  }
  thus ?thesis using assms unfolding less_eq_vec_def by auto
qed

(* A property P quantified over all vectors of carrier_vec (n + m) is related
   to P quantified over all appended pairs x1 @v x2 with x1 in carrier_vec n
   and x2 in carrier_vec m.
   One direction is closed by force. For the other, an arbitrary x is written
   as the append of its first n entries and its remaining m entries, after
   which the hypothesis on appended pairs applies. *)
lemma all_vec_append: "( x  carrier_vec (n + m). P x)  ( x1  carrier_vec n.  x2  carrier_vec m. P (x1 @v x2))" 
proof (standard, force, intro ballI, goal_cases)
  case (1 x)
  (* Decompose x at position n. *)
  have "x = vec n (λ i. x $ i) @v vec m (λ i. x $ (n + i))" 
    by (rule eq_vecI, insert 1(2), auto)
  hence "P x = P (vec n (λ i. x $ i) @v vec m (λ i. x $ (n + i)))" by simp
  also have "" using 1 by auto
  finally show ?case .
qed


(* Assembles one matrix from four blocks laid out as
     A B
     C D
   The result has dim_row A + dim_row D rows and dim_col A + dim_col D columns.
   The split points are dim_row A and dim_col A; entries of B, C and D are read
   at indices shifted by these split points. The dimensions of B and C do not
   enter the definition. *)
definition four_block_mat :: "'a mat  'a mat  'a mat  'a mat  'a mat" where
  "four_block_mat A B C D =
    (let nra = dim_row A; nrd = dim_row D;
         nca = dim_col A; ncd = dim_col D
       in
    mat (nra + nrd) (nca + ncd) (λ (i,j). if i < nra then
      if j < nca then A $$ (i,j) else B $$ (i,j - nca)
      else if j < nca then C $$ (i - nra, j) else D $$ (i - nra, j - nca)))"

(* Three simp facts about four_block_mat A B C D: (1) for in-bounds indices
   (i,j) the entry is looked up in A, B, C or D depending on how i and j
   compare with dim_row A and dim_col A; (2) the row count is
   dim_row A + dim_row D; (3) the column count is dim_col A + dim_col D. *)
lemma index_mat_four_block[simp]:
  "i < dim_row A + dim_row D  j < dim_col A + dim_col D  four_block_mat A B C D $$ (i,j)
  = (if i < dim_row A then
      if j < dim_col A then A $$ (i,j) else B $$ (i,j - dim_col A)
      else if j < dim_col A then C $$ (i - dim_row A, j) else D $$ (i - dim_row A, j - dim_col A))"
  "dim_row (four_block_mat A B C D) = dim_row A + dim_row D"
  "dim_col (four_block_mat A B C D) = dim_col A + dim_col D"
  unfolding four_block_mat_def Let_def by auto

(* Carrier of a four-block matrix: only the dimensions of A (nr1 x nc1) and
   D (nr2 x nc2) matter; the result lies in
   carrier_mat (nr1 + nr2) (nc1 + nc2). Simp rule. *)
lemma four_block_carrier_mat[simp]:
  "A  carrier_mat nr1 nc1  D  carrier_mat nr2 nc2 
  four_block_mat A B C D  carrier_mat (nr1 + nr2) (nc1 + nc2)"
  unfolding carrier_mat_def by auto

(* Congruence rule: four_block_mat applied to pairwise equal blocks gives equal
   matrices. Useful for rewriting the blocks one at a time. *)
lemma cong_four_block_mat: "A1 = B1  A2 = B2  A3 = B3  A4 = B4 
  four_block_mat A1 A2 A3 A4 = four_block_mat B1 B2 B3 B4" by auto

(* The four-block matrix with 1m n1 and 1m n2 on the diagonal and zero matrices
   off the diagonal is 1m (n1 + n2). Simp rule. *)
lemma four_block_one_mat[simp]:
  "four_block_mat (1m n1) (0m n1 n2) (0m n2 n1) (1m n2) = 1m (n1 + n2)"
  by (rule eq_matI, auto)

(* A four-block matrix built from four dimension-compatible zero matrices is
   the zero matrix of the combined dimensions. Simp rule. *)
lemma four_block_zero_mat[simp]:
  "four_block_mat (0m nr1 nc1) (0m nr1 nc2) (0m nr2 nc1) (0m nr2 nc2) = 0m (nr1 + nr2) (nc1 + nc2)"
  by (rule eq_matI, auto)

(* Rows of a four-block matrix with dimension-compatible blocks
   (A: nr1 x nc1, B: nr1 x nc2, C: nr2 x nc1, D: nr2 x nc2):
   a row in the upper part (i < nr1) is row i of A appended with row i of B;
   a row in the lower part is row i - nr1 of C appended with row i - nr1 of D.
   The two statements are proved separately, each by eq_vecI. *)
lemma row_four_block_mat:
  assumes c: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows
  "i < nr1  row (four_block_mat A B C D) i = row A i @v row B i" (is "_  ?AB")
  "¬ i < nr1  i < nr1 + nr2  row (four_block_mat A B C D) i = row C (i - nr1) @v row D (i - nr1)"
  (is "_  _  ?CD")
proof -
  assume i: "i < nr1"
  show ?AB by (rule eq_vecI, insert i c, auto)
next
  assume i: "¬ i < nr1" "i < nr1 + nr2"
  show ?CD by (rule eq_vecI, insert i c, auto)
qed

(* Columns of a four-block matrix with dimension-compatible blocks
   (A: nr1 x nc1, B: nr1 x nc2, C: nr2 x nc1, D: nr2 x nc2):
   a column in the left part (j < nc1) is column j of A appended with column j
   of C; a column in the right part is column j - nc1 of B appended with
   column j - nc1 of D. The two statements are proved separately by eq_vecI. *)
lemma col_four_block_mat:
  assumes c: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows
  "j < nc1  col (four_block_mat A B C D) j = col A j @v col C j" (is "_  ?AC")
  "¬ j < nc1  j < nc1 + nc2  col (four_block_mat A B C D) j = col B (j - nc1) @v col D (j - nc1)"
  (is "_  _  ?BD")
proof -
  assume j: "j < nc1"
  show ?AC by (rule eq_vecI, insert j c, auto)
next
  assume j: "¬ j < nc1" "j < nc1 + nc2"
  show ?BD by (rule eq_vecI, insert j c, auto)
qed

(* Block-wise multiplication: the product of two four-block matrices with
   compatible block dimensions is the four-block matrix whose blocks are given
   by the usual 2-by-2 multiplication scheme, e.g. upper-left block
   A1 * A2 + B1 * C2.
   Proof outline: the row/column decomposition lemmas are specialised to the
   two factors (notes row and col). Four braced blocks then treat the four
   quadrants of the result separately; in each, the scalar product of an
   appended row with an appended column is split by scalar_prod_append at the
   inner dimensions n1 and n2. The final step combines the four cases
   entrywise. *)
lemma mult_four_block_mat: assumes
  c1: "A1  carrier_mat nr1 n1" "B1  carrier_mat nr1 n2" "C1  carrier_mat nr2 n1" "D1  carrier_mat nr2 n2" and
  c2: "A2  carrier_mat n1 nc1" "B2  carrier_mat n1 nc2" "C2  carrier_mat n2 nc1" "D2  carrier_mat n2 nc2"
  shows "four_block_mat A1 B1 C1 D1 * four_block_mat A2 B2 C2 D2
  = four_block_mat (A1 * A2 + B1 * C2) (A1 * B2 + B1 * D2)
    (C1 * A2 + D1 * C2) (C1 * B2 + D1 * D2)" (is "?M1 * ?M2 = _")
proof -
  note row = row_four_block_mat[OF c1]
  note col = col_four_block_mat[OF c2]
  (* Upper-left quadrant. *)
  {
    fix i j
    assume i: "i < nr1" and j: "j < nc1"
    have "row ?M1 i  col ?M2 j = row A1 i  col A2 j + row B1 i  col C2 j"
      unfolding row(1)[OF i] col(1)[OF j]
      by (rule scalar_prod_append[of _ n1 _ n2], insert c1 c2 i j, auto)
  }
  moreover
  (* Lower-left quadrant. *)
  {
    fix i j
    assume i: "¬ i < nr1" "i < nr1 + nr2" and j: "j < nc1"
    hence i': "i - nr1 < nr2" by auto
    have "row ?M1 i  col ?M2 j = row C1 (i - nr1)  col A2 j + row D1 (i - nr1)  col C2 j"
      unfolding row(2)[OF i] col(1)[OF j]
      by (rule scalar_prod_append[of _ n1 _ n2], insert c1 c2 i i' j, auto)
  }
  moreover
  (* Upper-right quadrant. *)
  {
    fix i j
    assume i: "i < nr1" and j: "¬ j < nc1" "j < nc1 + nc2"
    hence j': "j - nc1 < nc2" by auto
    have "row ?M1 i  col ?M2 j = row A1 i  col B2 (j - nc1) + row B1 i  col D2 (j - nc1)"
      unfolding row(1)[OF i] col(2)[OF j]
      by (rule scalar_prod_append[of _ n1 _ n2], insert c1 c2 i j' j, auto)
  }
  moreover
  (* Lower-right quadrant. *)
  {
    fix i j
    assume i: "¬ i < nr1" "i < nr1 + nr2" and j: "¬ j < nc1" "j < nc1 + nc2"
    hence i': "i - nr1 < nr2" and j': "j - nc1 < nc2" by auto
    have "row ?M1 i  col ?M2 j = row C1 (i - nr1)  col B2 (j - nc1) + row D1 (i - nr1)  col D2 (j - nc1)"
      unfolding row(2)[OF i] col(2)[OF j]
      by (rule scalar_prod_append[of _ n1 _ n2], insert c1 c2 i i' j' j, auto)
  }
  ultimately show ?thesis
    by (intro eq_matI, insert c1 c2, auto)
qed

(* Vertical stacking of two matrices (right-associative infix, priority 65),
   expressed through four_block_mat: A is the upper-left block, B the
   lower-left block, and the two right-hand blocks are zero matrices with
   0 columns. Consequently the column count of the result is dim_col A + 0. *)
definition append_rows :: "'a :: zero mat  'a mat  'a mat" (infixr "@r" 65)where
  "A @r B = four_block_mat A (0m (dim_row A) 0) B (0m (dim_row B) 0)" 

(* Stacking A (nr1 x nc) on top of B (nr2 x nc) gives a matrix in
   carrier_mat (nr1 + nr2) nc. Registered both as simp and intro rule. *)
lemma carrier_append_rows[simp,intro]: "A  carrier_mat nr1 nc  B  carrier_mat nr2 nc 
  A @r B  carrier_mat (nr1 + nr2) nc" 
  unfolding append_rows_def by auto

(* Column j of a matrix product A * B is A applied to column j of B, for
   A in carrier_mat nr n, B in carrier_mat n nc and j < nc. Simp rule.
   The structured part proves the indexwise equality; the remaining dimension
   goal is closed by the final auto. *)
lemma col_mult2[simp]:
  assumes A: "A : carrier_mat nr n"
      and B: "B : carrier_mat n nc"
      and j: "j < nc"
  shows "col (A * B) j = A *v col B j"
proof
  have AB: "A * B : carrier_mat nr nc" using A B by auto
  fix i assume i: "i < dim_vec (A *v col B j)"
  show "col (A * B) j $ i = (A *v col B j) $ i"
    using A B AB j i by simp
qed auto

(* Expresses a matrix-vector product through a matrix-matrix product:
   A applied to v is column 0 of A multiplied with the one-column matrix
   mat_of_cols nc [v]. Follows from col_mult2. *)
lemma mat_vec_as_mat_mat_mult: assumes A: "A  carrier_mat nr nc" 
  and v: "v  carrier_vec nc" 
shows "A *v v = col (A * mat_of_cols nc [v]) 0"  
  by (subst col_mult2[OF A], insert v, auto)

(* Applying the stacked matrix A @r B to v yields the result of A applied to v
   appended with the result of B applied to v, for A (nr1 x nc), B (nr2 x nc)
   and v in carrier_vec nc.
   Proof outline: the vector v is turned into the one-column matrix
   mat_of_cols nc [v], which is itself written as a four-block matrix ?Fb2
   with empty (0-row or 0-column) padding blocks. The product ?Fb1 * ?Fb2 is
   then computed with mult_four_block_mat, the padding terms are simplified
   away, column 0 is taken with col_four_block_mat, and
   mat_vec_as_mat_mat_mult translates back to matrix-vector products. *)
lemma mat_mult_append: assumes A: "A  carrier_mat nr1 nc" 
  and B: "B  carrier_mat nr2 nc" 
  and v: "v  carrier_vec nc" 
shows "(A @r B) *v v = (A *v v) @v (B *v v)" 
proof -
  let ?Fb1 = "four_block_mat A (0m nr1 0) B (0m nr2 0)" 
  let ?Fb2 = "four_block_mat (mat_of_cols nc [v]) (0m nc 0) (0m 0 1) (0m 0 0)" 
  (* The padded four-block form is just the one-column matrix itself. *)
  have id: "?Fb2 = mat_of_cols nc [v]" 
    using v by auto
  have "(A @r B) *v v = col (?Fb1 * ?Fb2) 0" unfolding id
    by (subst mat_vec_as_mat_mat_mult[OF _ v], insert A B, auto simp: append_rows_def)
  also have "?Fb1 * ?Fb2 = four_block_mat (A * mat_of_cols nc [v] + 0m nr1 0 * 0m 0 1) (A * 0m nc 0 + 0m nr1 0 * 0m 0 0)
     (B * mat_of_cols nc [v] + 0m nr2 0 * 0m 0 1) (B * 0m nc 0 + 0m nr2 0 * 0m 0 0)" 
    by (rule mult_four_block_mat[OF A _ B], auto)
  (* Remove the contributions of the empty padding blocks. *)
  also have "(A * mat_of_cols nc [v] + 0m nr1 0 * 0m 0 1) = A * mat_of_cols nc [v]" 
    using A v by auto
  also have "(B * mat_of_cols nc [v] + 0m nr2 0 * 0m 0 1) = B * mat_of_cols nc [v]" 
    using B v by auto
  also have "(A * 0m nc 0 + 0m nr1 0 * 0m 0 0) = 0m nr1 0" using A by auto 
  also have "(B * 0m nc 0 + 0m nr2 0 * 0m 0 0) = 0m nr2 0" using B by auto
  finally have "(A @r B) *v v = col (four_block_mat (A * mat_of_cols nc [v]) (0m nr1 0) (B * mat_of_cols nc [v]) (0m nr2 0)) 0" .
  also have " = col (A * mat_of_cols nc [v]) 0 @v col (B * mat_of_cols nc [v]) 0" 
    by (rule col_four_block_mat, insert A B v, auto)
  also have "col (A * mat_of_cols nc [v]) 0 = A *v v" 
    by (rule mat_vec_as_mat_mat_mult[symmetric, OF A v])
  also have "col (B * mat_of_cols nc [v]) 0 = B *v v" 
    by (rule mat_vec_as_mat_mat_mult[symmetric, OF B v])
  finally show ?thesis .
qed
 
(* Relates a comparison between the stacked system (A @r B) applied to v and
   the appended right-hand side a @v b to the two separate comparisons of
   A applied to v with a and of B applied to v with b, where a has dimension
   nr1. Obtained by rewriting with mat_mult_append and applying
   append_vec_le. *)
lemma append_rows_le: assumes A: "A  carrier_mat nr1 nc" 
  and B: "B  carrier_mat nr2 nc" 
  and a: "a  carrier_vec nr1" 
  and v: "v  carrier_vec nc"
shows "(A @r B) *v v  (a @v b)  A *v v  a  B *v v  b" 
  unfolding mat_mult_append[OF A B v]
  by (rule append_vec_le[OF _ a], insert A v, auto)


(* Every element of a four-block matrix with dimension-compatible blocks is an
   element of one of the four blocks; i.e. elements_mat of the block matrix is
   bounded by the combination of elements_mat of A, B, C and D.
   Proof: pick an element a together with an in-bounds position (i,j) where it
   occurs, then distinguish the four quadrants by comparing i with nr1 and j
   with nc1; in each case a is identified as an entry of the respective
   block. *)
lemma elements_four_block_mat:
  assumes c: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows
  "elements_mat (four_block_mat A B C D) 
   elements_mat A  elements_mat B  elements_mat C  elements_mat D"
   (is "elements_mat ?four  _")
proof rule
  fix a assume "a  elements_mat ?four"
  then obtain i j
    where i4: "i < dim_row ?four" and j4: "j < dim_col ?four" and a: "a = ?four $$ (i, j)"
    by auto
  show "a  elements_mat A  elements_mat B  elements_mat C  elements_mat D"
  proof (cases "i < nr1")
    (* Upper half: the entry comes from A or B. *)
    case True note i1 = this
    show ?thesis
    proof (cases "j < nc1")
      case True
      then have "a = A $$ (i,j)" using c i1 a by simp
      thus ?thesis using c i1 True by auto next
      case False
      then have "a = B $$ (i,j-nc1)" using c i1 a j4 by simp
      moreover have "j - nc1 < nc2" using c j4 False by auto
      ultimately show ?thesis using c i1 by auto
    qed next
    (* Lower half: the entry comes from C or D. *)
    case False note i1 = this
    have i2: "i - nr1 < nr2" using c i1 i4 by auto
    show ?thesis
    proof (cases "j < nc1")
      case True
      then have "a = C $$ (i-nr1,j)" using c i2 a i1 by simp
      thus ?thesis using c i2 True by auto next
      case False
      then have "a = D $$ (i-nr1,j-nc1)" using c i2 a i1 j4 by simp
      moreover have "j - nc1 < nc2" using c j4 False by auto
      ultimately show ?thesis using c i2 by auto
    qed
  qed
qed

(* Associativity of the block-diagonal combination of two matrices.
   FB Bb Cc is defined locally as the four-block matrix with Bb and Cc on the
   diagonal and zero matrices of matching dimensions off the diagonal; the
   lemma states FB A (FB B C) = FB (FB A B) C.
   Proof outline: first record the dimensions of both sides and of the
   intermediate combinations (dL, dR, dBC, dAB), then compare entries at an
   arbitrary in-bounds position (i,j) by case distinction on whether i and j
   fall into the part belonging to A. In the remaining case both sides are
   reduced to the same entry of FB B C. *)
lemma assoc_four_block_mat: fixes FB :: "'a mat  'a mat  'a :: zero mat"
  defines FB: "FB  λ Bb Cc. four_block_mat Bb (0m (dim_row Bb) (dim_col Cc)) (0m (dim_row Cc) (dim_col Bb)) Cc"
  shows "FB A (FB B C) = FB (FB A B) C" (is "?L = ?R")
proof -
  let ?ar = "dim_row A" let ?ac = "dim_col A"
  let ?br = "dim_row B" let ?bc = "dim_col B"
  let ?cr = "dim_row C" let ?cc = "dim_col C"
  let ?r = "?ar + ?br + ?cr" let ?c = "?ac + ?bc + ?cc"
  let ?BC = "FB B C" let ?AB = "FB A B"
  have dL: "dim_row ?L = ?r" "dim_col ?L = ?c" unfolding FB by auto
  have dR: "dim_row ?R = ?ar + ?br + ?cr" "dim_col ?R = ?ac + ?bc + ?cc" unfolding FB by auto
  have dBC: "dim_row ?BC = ?br + ?cr" "dim_col ?BC = ?bc + ?cc" unfolding FB by auto
  have dAB: "dim_row ?AB = ?ar + ?br" "dim_col ?AB = ?ac + ?bc" unfolding FB by auto
  show ?thesis
  proof (intro eq_matI[of ?R ?L, unfolded dL dR, OF _ refl refl])
    fix i j
    assume i: "i < ?r" and j: "j < ?c"
    show "?L $$ (i,j) = ?R $$ (i,j)"
    proof (cases "i < ?ar")
      (* Row index inside the part of A. *)
      case True note i = this
      thus ?thesis using j
        by (cases "j < ?ac", auto simp: FB)
    next
      case False note ii = this
      show ?thesis
      proof (cases "j < ?ac")
        case True
        with i ii show ?thesis unfolding FB by auto
      next
        (* Both indices lie beyond A: compare both sides with FB B C. *)
        case False note jj = this
        from j jj i ii have L: "?L $$ (i,j) = ?BC $$ (i - ?ar, j - ?ac)" unfolding FB by auto
        have R: "?R $$ (i,j) = ?BC $$ (i - ?ar, j - ?ac)" using ii jj i j
          by (cases "i < ?ar + ?br"; cases "j < ?ac + ?bc", auto simp: FB)
        show ?thesis unfolding L R ..
      qed
    qed
  qed
qed

(* Splits a matrix A at row sr and column sc into four blocks, returned as a
   tuple (A1, A2, A3, A4):
     A1: sr x sc, the entries of A at the same positions;
     A2: sr x (dim_col A - sc), columns shifted by sc;
     A3: (dim_row A - sr) x sc, rows shifted by sr;
     A4: (dim_row A - sr) x (dim_col A - sc), both indices shifted.
   The subtractions are on natural numbers. *)
definition split_block :: "'a mat  nat  nat  ('a mat × 'a mat × 'a mat × 'a mat)"
  where "split_block A sr sc = (let
    nr = dim_row A; nc = dim_col A;
    nr2 = nr - sr; nc2 = nc - sc;
    A1 = mat sr sc (λ ij. A $$ ij);
    A2 = mat sr nc2 (λ (i,j). A $$ (i,j+sc));
    A3 = mat nr2 sc (λ (i,j). A $$ (i+sr,j));
    A4 = mat nr2 nc2 (λ (i,j). A $$ (i+sr,j+sc))
  in (A1,A2,A3,A4))"

(* Correctness of split_block: if A has sr1 + sr2 rows and sc1 + sc2 columns
   and is split at (sr1, sc1), then the four resulting blocks have the expected
   carriers and reassembling them with four_block_mat gives back A. *)
lemma split_block: assumes res: "split_block A sr1 sc1 = (A1,A2,A3,A4)"
  and dims: "dim_row A = sr1 + sr2" "dim_col A = sc1 + sc2"
  shows "A1  carrier_mat sr1 sc1" "A2  carrier_mat sr1 sc2"
    "A3  carrier_mat sr2 sc1" "A4  carrier_mat sr2 sc2"
    "A = four_block_mat A1 A2 A3 A4"
  using res unfolding split_block_def Let_def
  by (auto simp: dims)

text ‹Using @{const four_block_mat} we define block-diagonal matrices.›

(* Block-diagonal matrix built from a list of matrices, by recursion on the
   list: the empty list gives the 0-by-0 zero matrix; for A # As the head A is
   placed in the upper-left corner, the recursively built matrix B in the
   lower-right corner, and the two off-diagonal blocks are zero matrices of
   matching dimensions. *)
fun diag_block_mat :: "'a :: zero mat list  'a mat" where
  "diag_block_mat [] = 0m 0 0"
| "diag_block_mat (A # As) = (let
     B = diag_block_mat As
     in four_block_mat A (0m (dim_row A) (dim_col B)) (0m (dim_row B) (dim_col A)) B)"

(* Dimensions of a block-diagonal matrix: the row count is the sum of the row
   counts of the blocks, and likewise for columns. Both facts are established
   together by one induction over the list and then split. *)
lemma dim_diag_block_mat:
  "dim_row (diag_block_mat As) = sum_list (map dim_row As)" (is "?row")
  "dim_col (diag_block_mat As) = sum_list (map dim_col As)" (is "?col")
proof -
  have "?row  ?col"
    by (induct As, auto simp: Let_def)
  thus ?row and ?col by auto
qed

(* The block-diagonal matrix of a singleton list is the single block itself.
   Simp rule. *)
lemma diag_block_mat_singleton[simp]: "diag_block_mat [A] = A"
  by auto

(* The block-diagonal matrix of an appended list As @ Bs is the block-diagonal
   combination of diag_block_mat As and diag_block_mat Bs, with zero matrices
   off the diagonal. Proved by induction on As; the step case reduces to
   assoc_four_block_mat after unfolding the definitions and the induction
   hypothesis. *)
lemma diag_block_mat_append: "diag_block_mat (As @ Bs) =
  (let A = diag_block_mat As; B = diag_block_mat Bs
  in four_block_mat A (0m (dim_row A) (dim_col B)) (0m (dim_row B) (dim_col A)) B)"
  unfolding Let_def
proof (induct As)
  case (Cons A As)
  show ?case
    unfolding append.simps
    unfolding diag_block_mat.simps Let_def
    unfolding Cons
    by (rule assoc_four_block_mat)
qed auto

(* Special case of diag_block_mat_append for adding a single block B at the
   end of the list: the result combines diag_block_mat As with B itself. *)
lemma diag_block_mat_last: "diag_block_mat (As @ [B]) =
  (let A = diag_block_mat As
  in four_block_mat A (0m (dim_row A) (dim_col B)) (0m (dim_row B) (dim_col A)) B)"
  unfolding diag_block_mat_append diag_block_mat_singleton by auto


(* If every matrix in the list satisfies square_mat, then so does the
   block-diagonal matrix built from the list. Proved by induction on As. *)
lemma diag_block_mat_square:
  "Ball (set As) square_mat  square_mat (diag_block_mat As)"
by (induct As, auto simp:Let_def)

(* Replacing every block A by 1m (dim_row A) and forming the block-diagonal
   matrix yields 1m of the summed row dimensions. Simp rule; proved by
   induction on As. *)
lemma diag_block_one_mat[simp]:
  "diag_block_mat (map (λA. 1m (dim_row A)) As) = (1m (sum_list (map dim_row As)))"
  by (induct As, auto simp: Let_def)

(* The elements of a block-diagonal matrix are bounded by 0 together with the
   elements of the individual blocks.
   Proof by induction on the list. In the step case the four blocks of
   diag_block_mat (A # As) are A, two zero matrices ?B and ?C, and the
   recursive result ?D; elements_four_block_mat bounds the elements by those of
   the four blocks, elements_0_mat handles the zero blocks, and the induction
   hypothesis handles ?D. *)
lemma elements_diag_block_mat:
  "elements_mat (diag_block_mat As)  {0}   (set (map elements_mat As))"
proof (induct As)
  case Nil then show ?case using dim_diag_block_mat[of Nil] by auto next
  case (Cons A As)
    let ?D = "diag_block_mat As"
    let ?B = "0m (dim_row A) (dim_col ?D)"
    let ?C = "0m (dim_row ?D) (dim_col A)"
    (* Carrier facts in the shape required by elements_four_block_mat. *)
    have A: "A  carrier_mat (dim_row A) (dim_col A)" by auto
    have B: "?B  carrier_mat (dim_row A) (dim_col ?D)" by auto
    have C: "?C  carrier_mat (dim_row ?D) (dim_col A)" by auto
    have D: "?D  carrier_mat (dim_row ?D) (dim_col ?D)" by auto
    have
      "elements_mat (diag_block_mat (A#As)) 
       elements_mat A  elements_mat ?B  elements_mat ?C  elements_mat ?D"
      unfolding diag_block_mat.simps Let_def
      using elements_four_block_mat[OF A B C D] elements_0_mat
      by auto
    also have "...  {0}  elements_mat A  elements_mat ?D"
      using elements_0_mat by auto
    finally show ?case using Cons by auto
qed

(* Powers of a block-diagonal matrix with square blocks are computed block by
   block: the n-th power of diag_block_mat As is the block-diagonal matrix of
   the n-th powers of the blocks.
   Proof outline: outer induction on the exponent n.
   - Exponent 0: both sides are shown equal to 1m of the summed dimensions,
     using diag_block_mat_square and diag_block_one_mat.
   - Exponent Suc n: after applying the induction hypothesis the goal is that
     multiplying the block-diagonal matrix of n-th powers (?An) with the
     original block-diagonal matrix gives the block-diagonal matrix of the
     products (?Asn). This is shown by an inner induction on the list As;
     the step case records the relevant dimensions as local simp facts (with
     n1 the size of the head block and n2 the summed size of the rest) and then
     applies mult_four_block_mat with explicitly instantiated dimensions. *)
lemma diag_block_pow_mat: assumes sq: "Ball (set As) square_mat"
  shows "diag_block_mat As ^m n = diag_block_mat (map (λ A. A ^m n) As)" (is "?As ^m _ = _")
proof (induct n)
  case 0
  have "?As ^m 0 = 1m (dim_row ?As)" by simp
  also have "dim_row ?As = sum_list (map dim_row As)"
    using diag_block_mat_square[OF sq] unfolding dim_diag_block_mat by auto
  also have "1m  = diag_block_mat (map (λA. 1m (dim_row A)) As)" by simp
  also have " = diag_block_mat (map (λ A. A ^m 0) As)" by simp
  finally show ?case .
next
  case (Suc n)
  let ?An = "λ As. diag_block_mat (map (λA. A ^m n) As)"
  let ?Asn = "λ As. diag_block_mat (map (λA. A ^m n * A) As)"
  from Suc have "?case = (?An As * diag_block_mat As = ?Asn As)" by simp
  also have "" using sq
  proof (induct As)
    case (Cons A As)
    hence IH: "?An As * diag_block_mat As = ?Asn As"
      and sq: "Ball (set As) square_mat" and A: "dim_col A = dim_row A" by auto
    (* Squareness is preserved by taking powers and products of the blocks. *)
    have sq2: "Ball (set (List.map (λA. A ^m n) As)) square_mat"
      and sq3: "Ball (set (List.map (λA. A ^m n * A) As)) square_mat"
      using sq by auto
    define n1 where "n1 = dim_row A"
    define n2 where "n2 = sum_list (map dim_row As)"
    from A have A: "A  carrier_mat n1 n1" unfolding n1_def carrier_mat_def by simp
    (* Local simp facts: all three block-diagonal matrices over As are n2 x n2. *)
    have [simp]: "dim_col (?An As) = n2" "dim_row (?An As) = n2"
      unfolding n2_def
      using diag_block_mat_square[OF sq2,unfolded square_mat.simps]
      unfolding dim_diag_block_mat map_map by (auto simp:o_def)
    have [simp]: "dim_col (?Asn As) = n2" "dim_row (?Asn As) = n2"
      unfolding n2_def
      using diag_block_mat_square[OF sq3,unfolded square_mat.simps]
      unfolding dim_diag_block_mat map_map by (auto simp:o_def)
    have [simp]:
      "dim_row (diag_block_mat As) = n2"
      "dim_col (diag_block_mat As) = n2"
      unfolding n2_def
      using diag_block_mat_square[OF sq,unfolded square_mat.simps]
      unfolding dim_diag_block_mat by auto

    have [simp]: "diag_block_mat As  carrier_mat n2 n2" unfolding carrier_mat_def by simp
    have [simp]: "?An As  carrier_mat n2 n2" unfolding carrier_mat_def by simp
    show ?case unfolding diag_block_mat.simps Let_def list.simps
      by (subst mult_four_block_mat[of _ n1 n1 _ n2 _ n2 _ _ n1 _ n2],
      insert A, auto simp: IH)
  qed auto
  finally show ?case by simp
qed

(* If every block in the list is square and has 0 at all in-bounds positions
   (i,j) with j < i, then the block-diagonal matrix also has 0 at every
   position (i,j) with j < i and i within its row range.
   Proof by induction on the list with i and j arbitrary. With ?n1 the size of
   the head block A and ?n2 the summed size of the remaining blocks, three
   situations are distinguished: both indices in the part of A (use the
   hypothesis on A), i beyond A but j inside A (a zero off-diagonal block), and
   both beyond A (apply the induction hypothesis at the shifted indices).
   The facts Cons(1) .. Cons(5) refer to the induction hypothesis followed by
   the four premises of the lemma in order. *)
lemma diag_block_upper_triangular: assumes
    " A i j. A  set As  j < i  i < dim_row A  A $$ (i,j) = 0"
  and "Ball (set As) square_mat"
  and "j < i" "i < dim_row (diag_block_mat As)"
  shows "diag_block_mat As $$ (i,j) = 0"
  using assms
proof (induct As arbitrary: i j)
  case (Cons A As i j)
  let ?n1 = "dim_row A"
  let ?n2 = "sum_list (map dim_row As)"
  from Cons have [simp]: "dim_col A = ?n1" by simp
  from Cons have "Ball (set As) square_mat" by auto
  note [simp] = diag_block_mat_square[OF this,unfolded square_mat.simps]
  note [simp] = dim_diag_block_mat(1)
  from Cons(5) have i: "i < ?n1 + ?n2" by simp
  show ?case
  proof (cases "i < ?n1")
    (* Both indices fall into the head block A, since j < i. *)
    case True
    with Cons(4) have j: "j < ?n1" by auto
    with True Cons(2)[of A, OF _ Cons(4)] show ?thesis
      by (simp add: Let_def)
  next
    case False note iAs = this
    show ?thesis
    proof (cases "j < ?n1")
      (* Lower-left zero block. *)
      case True
      with i iAs show ?thesis by (simp add: Let_def)
    next
      (* Both indices beyond A: recurse into the remaining blocks. *)
      case False note jAs = this
      from Cons(4) i have j: "j < ?n1 + ?n2" by auto
      show ?thesis using iAs jAs i j
        by (simp add: Let_def, subst Cons(1), insert Cons(2-4), auto)
    qed
  qed
qed simp

(* Scalar multiplication distributes over the four blocks of a four_block_mat,
   provided the blocks have compatible dimensions. *)
lemma smult_four_block_mat: assumes c: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows "a m four_block_mat A B C D = four_block_mat (a m A) (a m B) (a m C) (a m D)"
  by (rule eq_matI, insert c, auto)

(* Mapping a function over a four_block_mat is the same as mapping it over each
   of the four blocks, provided the blocks have compatible dimensions. *)
lemma map_four_block_mat: assumes c: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows "map_mat f (four_block_mat A B C D) = four_block_mat (map_mat f A) (map_mat f B) (map_mat f C) (map_mat f D)"
  by (rule eq_matI, insert c, auto)

(* Addition of two four_block_mat matrices with identical block dimensions is
   performed block by block. *)
lemma add_four_block_mat: assumes
  c1: "A1  carrier_mat nr1 nc1" "B1  carrier_mat nr1 nc2" "C1  carrier_mat nr2 nc1" "D1  carrier_mat nr2 nc2" and
  c2: "A2  carrier_mat nr1 nc1" "B2  carrier_mat nr1 nc2" "C2  carrier_mat nr2 nc1" "D2  carrier_mat nr2 nc2"
  shows "four_block_mat A1 B1 C1 D1 + four_block_mat A2 B2 C2 D2
  = four_block_mat (A1 + A2) (B1 + B2) (C1 + C2) (D1 + D2)"
  by (rule eq_matI, insert assms, auto)


(* For square diagonal blocks A and D, the diagonal of four_block_mat A B C D is
   the diagonal of A followed by the diagonal of D; B and C do not contribute. *)
lemma diag_four_block_mat: assumes c: "A  carrier_mat n1 n1"
   "D  carrier_mat n2 n2"
  shows "diag_mat (four_block_mat A B C D) = diag_mat A @ diag_mat D"
  by (rule nth_equalityI, insert c, auto simp: diag_mat_def nth_append)

(* mk_diagonal as: the block-diagonal matrix assembled from one 1x1 matrix per
   list element, each holding that element. *)
definition mk_diagonal :: "'a::zero list  'a mat"
  where "mk_diagonal as = diag_block_mat (map (λa. mat (Suc 0) (Suc 0) (λ_. a)) as)"

(* mk_diagonal as has as many rows and as many columns as the list as has elements. *)
lemma mk_diagonal_dim:
  "dim_row (mk_diagonal as) = length as" "dim_col (mk_diagonal as) = length as"
  unfolding mk_diagonal_def by(induct as, auto simp: Let_def)

(* The matrix produced by mk_diagonal satisfies diagonal_mat.
   Proof by induction on the list: for a#as the entry (i,j) with distinct i, j is
   handled by case distinction on whether i or j is 0 (the row/column of the
   1x1 head block) and otherwise by the induction hypothesis at (i-1,j-1). *)
lemma mk_diagonal_diagonal: "diagonal_mat (mk_diagonal as)"
  unfolding mk_diagonal_def
proof (induct as)
  case Nil show ?case unfolding mk_diagonal_def diagonal_mat_def by simp next
  case (Cons a as)
    (* abbreviations: ?A is the 1x1 head block, ?AS the matrix for the tail,
       ?AAS the matrix for the whole list *)
    let ?n = "length (a#as)"
    let ?A = "mat (Suc 0) (Suc 0) (λ_. a)"
    let ?f = "map (λa. mat (Suc 0) (Suc 0) (λ_. a))"
    let ?AS = "diag_block_mat (?f as)"
    let ?AAS = "diag_block_mat (?f (a#as))"
    show ?case
      unfolding diagonal_mat_def
    proof(intro allI impI)
      fix i j assume ir: "i < dim_row ?AAS" and jc: "j < dim_col ?AAS" and ij: "i  j"
      (* dimensions of ?AAS are one more than those of ?AS *)
      hence ir2: "i < 1 + dim_row ?AS" and jc2: "j < 1 + dim_col ?AS"
        unfolding dim_row_mat list.map diag_block_mat.simps Let_def
        by auto
      show "?AAS $$ (i,j) = 0"
      proof (cases "i = 0")
        case True
          then show ?thesis using jc ij by (auto simp: Let_def) next
        case False note i0 = this
          show ?thesis
          proof (cases "j = 0")
            case True
              then show ?thesis using ir ij by (auto simp: Let_def) next
            case False
              (* both indices are positive: shift by one and use the induction hypothesis *)
              have ir3: "i-1 < dim_row ?AS" and jc3: "j-1 < dim_col ?AS"
                using ir2 jc2 i0 False by auto
              have IH: "i j. i < dim_row ?AS  j < dim_col ?AS  i  j 
                ?AS $$ (i,j) = 0"
                using Cons unfolding diagonal_mat_def by auto
              have "?AS $$ (i-1,j-1) = 0"
                using IH[OF ir3 jc3] i0 False ij by auto
              thus ?thesis using ir jc ij by (simp add: Let_def)
          qed
      qed
    qed
qed

(* orthogonal_mat A: with B = transpose_mat A * A, the matrix B is diagonal and
   the diagonal entries B $$ (i,i) for i < dim_col A are constrained to be
   non-zero (see orthogonal_matD for the column-wise reading). *)
definition orthogonal_mat :: "'a::semiring_0 mat  bool"
  where "orthogonal_mat A 
    let B = transpose_mat A * A in
    diagonal_mat B  (i<dim_col A. B $$ (i,i)  0)"

(* Elimination rule for orthogonal_mat: for in-range column indices i and j,
   the scalar product of columns i and j of A is 0 exactly when i and j differ. *)
lemma orthogonal_matD[elim]:
  "orthogonal_mat A 
   i < dim_col A  j < dim_col A  (col A i  col A j = 0) = (i  j)"
  unfolding orthogonal_mat_def diagonal_mat_def by auto

(* Introduction rule for orthogonal_mat: it suffices to show, for all in-range
   column indices i and j, that the scalar product of columns i and j is 0
   exactly when i and j differ. *)
lemma orthogonal_matI[intro]:
  "(i j. i < dim_col A  j < dim_col A  (col A i  col A j = 0) = (i  j)) 
   orthogonal_mat A"
  unfolding orthogonal_mat_def diagonal_mat_def by auto

(* orthogonal vs: for all valid list positions i and j, the scalar product of
   the i-th and j-th vector of vs is 0 exactly when i and j differ. *)
definition orthogonal :: "'a::semiring_0 vec list  bool"
  where "orthogonal vs 
    i j. i < length vs  j < length vs 
      (vs ! i  vs ! j = 0) = (i  j)"

(* Elimination rule for orthogonal: unfolds the definition for given positions i and j. *)
lemma orthogonalD[elim]:
  "orthogonal vs  i < length vs  j < length vs 
  (nth vs i  nth vs j = 0) = (i  j)"
  unfolding orthogonal_def by auto

(* Introduction rule for orthogonal: establishes it from the defining property
   stated for arbitrary valid positions i and j. *)
lemma orthogonalI[intro]:
  "(i j. i < length vs  j < length vs  (nth vs i  nth vs j = 0) = (i  j)) 
   orthogonal vs"
  unfolding orthogonal_def by auto


(* Transposing a four_block_mat transposes every block and swaps the two
   off-diagonal blocks B and C. *)
lemma transpose_four_block_mat: assumes *: "A  carrier_mat nr1 nc1" "B  carrier_mat nr1 nc2"
  "C  carrier_mat nr2 nc1" "D  carrier_mat nr2 nc2"
  shows "transpose_mat (four_block_mat A B C D) =
    four_block_mat (transpose_mat A) (transpose_mat C) (transpose_mat B) (transpose_mat D)"
  by (rule eq_matI, insert *, auto)

(* The transpose of the n-by-m zero matrix is the m-by-n zero matrix (simp rule). *)
lemma zero_transpose_mat[simp]: "transpose_mat (0m n m) = (0m m n)"
  by (rule eq_matI, auto)

(* A four_block_mat with square upper-triangular diagonal blocks A and D and a
   zero lower-left block is upper triangular, for an arbitrary upper-right
   block B. *)
lemma upper_triangular_four_block: assumes AD: "A  carrier_mat n n" "D  carrier_mat m m"
  and ut: "upper_triangular A" "upper_triangular D"
  shows "upper_triangular (four_block_mat A B (0m m n) D)"
proof -
  let ?C = "four_block_mat A B (0m m n) D"
  from AD have dim: "dim_row ?C = n + m" "dim_col ?C = n + m" "dim_row A = n" by auto
  show ?thesis
  proof (rule upper_triangularI, unfold dim)
    fix i j
    assume *: "j < i" "i < n + m"
    show "?C $$ (i,j) = 0"
    proof (cases "i < n")
      (* row inside A: then j < i < n, so the entry comes from A *)
      case True
      with upper_triangularD[OF ut(1) *(1)] * AD show ?thesis by auto
    next
      (* row below A: split on whether the column falls left of D or inside D *)
      case False note i = this
      show ?thesis by (cases "j < n", insert upper_triangularD[OF ut(2)] * i AD, auto)
    qed
  qed
qed

(* The k-th power of a block-diagonal matrix with square blocks A and B (and
   zero off-diagonal blocks) is the block-diagonal matrix of the k-th powers of
   A and B. Proof by induction on k. *)
lemma pow_four_block_mat: assumes A: "A  carrier_mat n n"
  and B: "B  carrier_mat m m"
  shows "(four_block_mat A (0m n m) (0m m n) B) ^m k =
    four_block_mat (A ^m k) (0m n m) (0m m n) (B ^m k)"
proof (induct k)
  case (Suc k)
  (* ?FB builds the block-diagonal matrix from two diagonal blocks *)
  let ?FB = "λ A B. four_block_mat A (0m n m) (0m m n) B"
  let ?A = "?FB A B"
  let ?B = "?FB (A ^m k) (B ^m k)"
  from A B have Ak: "A ^m k  carrier_mat n n" and Bk: "B ^m k  carrier_mat m m" by auto
  have "?A ^m Suc k = ?A ^m k * ?A" by simp
  (* rewrite the k-th power with the induction hypothesis, then multiply blockwise *)
  also have "?A ^m k = ?B " by (rule Suc)
  also have "?B * ?A = ?FB (A ^m Suc k) (B ^m Suc k)"
    by (subst mult_four_block_mat[OF Ak _ _ Bk A _ _ B], insert A B, auto)
  finally show ?case .
qed (insert A B, auto)

(* Negating a scalar product equals the scalar product with the first vector
   negated, for vectors v and w of the same dimension n over a field.
   The proof pushes the negation into the sum and compares summands. *)
lemma uminus_scalar_prod:
  assumes [simp]: "v : carrier_vec n" "w : carrier_vec n"
  shows "- ((v::'a::field vec)  w) = (- v)  w"
  unfolding scalar_prod_def uminus_vec_def
  apply (subst sum_negf[symmetric])
proof (rule sum.cong[OF refl])
  fix i assume i: "i : {0 ..<dim_vec w}"
  have [simp]: "dim_vec v = n" "dim_vec w = n" by auto
  show "- (v $ i * w $ i) = vec (dim_vec v) (λi. - v $ i) $ i * w $ i"
    unfolding minus_mult_left using i by auto
qed


(* Injectivity of vector append: if v and v' have the same dimension n, then
   v @v w and v' @v w' are equal exactly when v = v' and w = w' (simp rule). *)
lemma append_vec_eq:
  assumes [simp]: "v : carrier_vec n" "v' : carrier_vec n"
  shows [simp]: "v @v w = v' @v w'  v = v'  w = w'" (is "?L  ?R")
proof
  have [simp]: "dim_vec v = n" "dim_vec v' = n" by auto
  { assume L: ?L
    (* first components agree: compare entries at positions below n *)
    have vv': "v = v'"
    proof
      fix i assume i: "i < dim_vec v'"
      have "(v @v w) $ i = (v' @v w') $ i" using L by auto
      thus "v $ i = v' $ i" using i by auto
    qed auto
    (* second components agree: equal dimensions, then compare entries at positions n + i *)
    moreover have "w = w'"
    proof
      show "dim_vec w = dim_vec w'" using vv' L
        by (metis add_diff_cancel_left' index_append_vec(2))
      moreover fix i assume i: "i < dim_vec w'"
      have "(v @v w) $ (n + i) = (v' @v w') $ (n + i)" using L by auto
      ultimately show "w $ i = w' $ i" using i by simp
    qed
    ultimately show ?R by simp
  }
qed auto

(* Adding two appended vectors with matching component dimensions equals
   appending the component-wise sums. *)
lemma append_vec_add:
  assumes [simp]: "v : carrier_vec n" "v' : carrier_vec n"
      and [simp]: "w : carrier_vec m" "w' : carrier_vec m"
  shows "(v @v w) + (v' @v w') = (v + v') @v (w + w')" (is "?L = ?R")
proof
  have [simp]: "dim_vec v = n" "dim_vec v' = n" by auto
  have [simp]: "dim_vec w = m" "dim_vec w' = m" by auto
  fix i assume i: "i < dim_vec ?R"
  (* split on whether position i lies in the first or the second component *)
  thus "?L $ i = ?R $ i" by (cases "i < n",auto)
qed auto


(* Multiplying a block-diagonal matrix (square blocks A and D, zero off-diagonal
   blocks) with an appended vector a @v d yields the two separate products
   A *v a and D *v d appended. *)
lemma mult_mat_vec_split:
  assumes A: "A : carrier_mat n n"
      and D: "D : carrier_mat m m"
      and a: "a : carrier_vec n"
      and d: "d : carrier_vec m"
  shows "four_block_mat A (0m n m) (0m m n) D *v (a @v d) = A *v a @v D *v d"
    (is "?A00D *v _ = ?r")
proof
  have A00D: "?A00D : carrier_mat (n+m) (n+m)" using four_block_carrier_mat[OF A D].
  fix i assume i: "i < dim_vec ?r"
  show "(?A00D *v (a @v d)) $ i = ?r $ i" (is "?li = _")
  proof (cases "i < n")
    (* upper part: row i of the block matrix is row i of A followed by zeros *)
    case True
      have "?li = (row A i @v 0v m)  (a @v d)"
        using A row_four_block_mat[OF A _ _ D] True by simp
      also have "... = row A i  a + 0v m  d"
        apply (rule scalar_prod_append) using A D a d True by auto
      also have "... = row A i  a" using d by simp
      finally show ?thesis using A True by auto
    (* lower part: row i of the block matrix is zeros followed by row (i - n) of D *)
    next case False
      let ?i = "i - n"
      have "?li = (0v n @v row D ?i)  (a @v d)"
        using i row_four_block_mat[OF A _ _ D] False A D by simp
      also have "... = 0v n  a + row D ?i  d"
        apply (rule scalar_prod_append) using A D a d False by auto
      also have "... = row D ?i  d" using a by simp
      finally show ?thesis using A D False i by auto
  qed
qed auto

(* Introduction rule for similar_mat_wit: P and Q are mutually inverse n-by-n
   matrices, A = P * B * Q, and A and B are n-by-n matrices as well. *)
lemma similar_mat_witI: assumes "P * Q = 1m n" "Q * P = 1m n" "A = P * B * Q"
  "A  carrier_mat n n" "B  carrier_mat n n" "P  carrier_mat n n" "Q  carrier_mat n n"
  shows "similar_mat_wit A B P Q" using assms unfolding similar_mat_wit_def Let_def by auto

(* Destruction rule for similar_mat_wit with n fixed to dim_row A: yields the
   inverse property of P and Q, the equation A = P * B * Q, and the n-by-n
   carrier membership of all four matrices. *)
lemma similar_mat_witD: assumes "n = dim_row A" "similar_mat_wit A B P Q"
  shows "P * Q = 1m n" "Q * P = 1m n" "A = P * B * Q"
  "A  carrier_mat n n" "B  carrier_mat n n" "P  carrier_mat n n" "Q  carrier_mat n n"
  using assms(2) unfolding similar_mat_wit_def Let_def assms(1)[symmetric] by auto

(* Variant of similar_mat_witD in which the dimension n is taken from a carrier
   assumption on A (its number of rows) instead of an equation n = dim_row A. *)
lemma similar_mat_witD2: assumes "A  carrier_mat n m" "similar_mat_wit A B P Q"
  shows "P * Q = 1m n" "Q * P = 1m n" "A = P * B * Q"
  "A  carrier_mat n n" "B  carrier_mat n n" "P  carrier_mat n n" "Q  carrier_mat n n"
  using similar_mat_witD[OF _ assms(2), of n] assms(1)[unfolded carrier_mat_def] by auto

(* Symmetry of similarity with witnesses: swapping A with B requires swapping
   the witnesses P and Q. *)
lemma similar_mat_wit_sym: assumes sim: "similar_mat_wit A B P Q"
  shows "similar_mat_wit B A Q P"
proof -
  from similar_mat_witD[OF refl sim] obtain n where
    AB: "{A, B, P, Q}  carrier_mat n n" "P * Q = 1m n" "Q * P = 1m n" and A: "A = P * B * Q" by blast
  hence *: "{B, A, Q, P}  carrier_mat n n" "Q * P = 1m n" "P * Q = 1m n" by auto
  let ?c = "λ A. A  carrier_mat n n"
  from * have Carr: "?c B" "?c P" "?c Q" by auto
  (* locally use associativity of multiplication for n-by-n matrices as a simp rule *)
  note [simp] = assoc_mult_mat[of _ n n _ n _ n]
  show ?thesis
  proof (rule similar_mat_witI[of _ _ n])
    (* Q * (P * B * Q) * P regroups to (Q * P) * B * (Q * P), which collapses to B *)
    have "Q * A * P = (Q * P) * B * (Q * P)"
      using Carr unfolding A by simp
    also have " = B" using Carr unfolding AB by simp
    finally show "B = Q * A * P" by simp
  qed (insert * AB, auto)
qed

(* Reflexivity of similarity with witnesses: a square matrix is similar to
   itself, witnessed by the identity matrix on both sides. *)
lemma similar_mat_wit_refl: assumes A: "A  carrier_mat n n"
  shows "similar_mat_wit A A (1m n) (1m n)"
  by (rule similar_mat_witI[OF _ _ _ A], insert A, auto)

(* Transitivity of similarity with witnesses: the witnesses compose as
   P * P' on the left and Q' * Q on the right. *)
lemma similar_mat_wit_trans: assumes AB: "similar_mat_wit A B P Q"
  and BC: "similar_mat_wit B C P' Q'"
  shows "similar_mat_wit A C (P * P') (Q' * Q)"
proof -
  (* unpack both assumptions; the dimension n is shared because B occurs in both *)
  from similar_mat_witD[OF refl AB] obtain n where
    AB: "{A, B, P, Q}  carrier_mat n n" "P * Q = 1m n" "Q * P = 1m n" "A = P * B * Q" by blast
  hence B: "B  carrier_mat n n" by auto
  from similar_mat_witD2[OF B BC] have
    BC: "{C, P', Q'}  carrier_mat n n" "P' * Q' = 1m n" "Q' * P' = 1m n" "B = P' * C * Q'" by auto
  let ?c = "λ A. A  carrier_mat n n"
  let ?P = "P * P'"
  let ?Q = "Q' * Q"
  from AB BC have carr: "?c A" "?c B" "?c C" "?c P" "?c P'" "?c Q" "?c Q'"
    and Carr: "{A, C, ?P, ?Q}  carrier_mat n n" by auto
  note [simp] = assoc_mult_mat[of _ n n _ n _ n]
  (* the conjugation equation for the composed witnesses *)
  have id: "A = ?P * C * ?Q" unfolding AB(4)[unfolded BC(4)] using carr
    by simp
  (* the composed witnesses are mutually inverse *)
  have "?P * ?Q = P * (P' * Q') * Q" using carr by simp
  also have " = 1m n" unfolding BC using carr AB by simp
  finally have PQ: "?P * ?Q = 1m n" .
  have "?Q * ?P = Q' * (Q * P) * P'" using carr by simp
  also have " = 1m n" unfolding AB using carr BC by simp
  finally have QP: "?Q * ?P = 1m n" .
  show ?thesis
    by (rule similar_mat_witI[OF PQ QP id], insert Carr, auto)
qed

(* Reflexivity of similar_mat for square matrices, derived from similar_mat_wit_refl. *)
lemma similar_mat_refl: "A  carrier_mat n n  similar_mat A A"
  using similar_mat_wit_refl unfolding similar_mat_def by blast

(* Transitivity of similar_mat, derived from similar_mat_wit_trans. *)
lemma similar_mat_trans: "similar_mat A B  similar_mat B C  similar_mat A C"
  using similar_mat_wit_trans unfolding similar_mat_def by blast

(* Symmetry of similar_mat, derived from similar_mat_wit_sym. *)
lemma similar_mat_sym: "similar_mat A B  similar_mat B A"
  using similar_mat_wit_sym unfolding similar_mat_def by blast

(* Similarity lifts to four-block matrices: if A1 ~ B1 via (P1,Q1) and A2 ~ B2
   via (P2,Q2), then the block matrix with diagonal blocks A1, A2 and
   off-diagonal blocks URA = P1 * UR * Q2, LLA = P2 * LL * Q1 is similar to the
   block matrix with blocks B1, UR, LL, B2. The witnesses are the block-diagonal
   matrices built from P1, P2 and from Q1, Q2. *)
lemma similar_mat_wit_four_block: assumes
      1: "similar_mat_wit A1 B1 P1 Q1"
  and 2: "similar_mat_wit A2 B2 P2 Q2"
  and URA: "URA = (P1 * UR * Q2)"
  and LLA: "LLA = (P2 * LL * Q1)"
  and A1: "A1  carrier_mat n n"
  and A2: "A2  carrier_mat m m"
  and LL: "LL  carrier_mat m n"
  and UR: "UR  carrier_mat n m"
  shows "similar_mat_wit (four_block_mat A1 URA LLA A2) (four_block_mat B1 UR LL B2)
    (four_block_mat P1 (0m n m) (0m m n) P2) (four_block_mat Q1 (0m n m) (0m m n) Q2)"
  (is "similar_mat_wit ?A ?B ?P ?Q")
proof -
  let ?n = "n + m"
  let ?O1 = "1m n"   let ?O2 = "1m m"   let ?O = "1m ?n"
  (* unpack the two witness assumptions; the names 1 and 2 are rebound to the
     conjugation equations *)
  from similar_mat_witD2[OF A1 1] have 11: "P1 * Q1 = ?O1" "Q1 * P1 = ?O1"
    and P1: "P1  carrier_mat n n" and Q1: "Q1  carrier_mat n n"
    and B1: "B1  carrier_mat n n" and 1: "A1 = P1 * B1 * Q1" by auto
  from similar_mat_witD2[OF A2 2] have 21: "P2 * Q2 = ?O2" "Q2 * P2 = ?O2"
    and P2: "P2  carrier_mat m m" and Q2: "Q2  carrier_mat m m"
    and B2: "B2  carrier_mat m m" and 2: "A2 = P2 * B2 * Q2" by auto
  (* the block-diagonal witnesses are mutually inverse *)
  have PQ1: "?P * ?Q = ?O"
    by (subst mult_four_block_mat[OF P1 _ _ P2 Q1 _ _ Q2], unfold 11 21, insert P1 P2 Q1 Q2,
      auto intro!: eq_matI)
  have QP1: "?Q * ?P = ?O"
    by (subst mult_four_block_mat[OF Q1 _ _ Q2 P1 _ _ P2], unfold 11 21, insert P1 P2 Q1 Q2,
      auto intro!: eq_matI)
  let ?PB = "?P * ?B"
  have P: "?P  carrier_mat ?n ?n" using P1 P2 by auto
  have Q: "?Q  carrier_mat ?n ?n" using Q1 Q2 by auto
  have B: "?B  carrier_mat ?n ?n" using B1 UR LL B2 by auto
  have PB: "?PB  carrier_mat ?n ?n" using P B by auto
  have PB1: "P1 * B1  carrier_mat n n" using P1 B1 by auto
  have PB2: "P2 * B2  carrier_mat m m" using P2 B2 by auto
  have P1UR: "P1 * UR  carrier_mat n m" using P1 UR by auto
  have P2LL: "P2 * LL  carrier_mat m n" using P2 LL by auto
  (* the fact name id is rebound three times: each step unfolds the previous one *)
  have id: "?PB = four_block_mat (P1 * B1) (P1 * UR) (P2 * LL) (P2 * B2)"
    by (subst mult_four_block_mat[OF P1 _ _ P2 B1 UR LL B2], insert P1 P2 B1 B2 LL UR, auto)
  have id: "?PB * ?Q = four_block_mat (P1 * B1 * Q1) (P1 * UR * Q2)
    (P2 * LL * Q1) (P2 * B2 * Q2)" unfolding id
    by (subst mult_four_block_mat[OF PB1 P1UR P2LL PB2 Q1 _ _ Q2],
    insert P1 P2 B1 B2 Q1 Q2 UR LL, auto)
  have id: "?A = ?P * ?B * ?Q" unfolding id 1 2 URA LLA ..
  show ?thesis
    by (rule similar_mat_witI[OF PQ1 QP1 id], insert A1 A2 B1 B2 Q1 Q2 P1 P2, auto)
qed


(* For a four-block matrix with zero lower-left block: if the diagonal blocks
   A1, A2 are similar to B1, B2, then there exists an upper-right block B0 such
   that the whole matrix is similar to the one with blocks B1, B0, 0, B2.
   The proof uses the witness B0 = Q1 * A0 * P2. *)
lemma similar_mat_four_block_0_ex: assumes
      1: "similar_mat A1 B1"
  and 2: "similar_mat A2 B2"
  and A0: "A0  carrier_mat n m"
  and A1: "A1  carrier_mat n n"
  and A2: "A2  carrier_mat m m"
  shows " B0. B0  carrier_mat n m  similar_mat (four_block_mat A1 A0 (0m m n) A2)
    (four_block_mat B1 B0 (0m m n) B2)"
proof -
  from 1[unfolded similar_mat_def] obtain P1 Q1 where 1: "similar_mat_wit A1 B1 P1 Q1" by auto
  note w1 = similar_mat_witD2[OF A1 1]
  from 2[unfolded similar_mat_def] obtain P2 Q2 where 2: "similar_mat_wit A2 B2 P2 Q2" by auto
  note w2 = similar_mat_witD2[OF A2 2]
  from w1 w2 have C: "B1  carrier_mat n n" "B2  carrier_mat m m" by auto
  (* the zero lower-left block stays zero under conjugation *)
  from w1 w2 have id: "0m m n = Q2 * 0m m n * P1" by simp
  let ?wit = "Q1 * A0 * P2"
  from w1 w2 A0 have wit: "?wit  carrier_mat n m" by auto
  (* apply the four-block lemma in the direction B ~ A, then flip the result back *)
  from similar_mat_wit_sym[OF similar_mat_wit_four_block[OF similar_mat_wit_sym[OF 1] similar_mat_wit_sym[OF 2]
    refl id C zero_carrier_mat A0]]
  have "similar_mat (four_block_mat A1 A0 (0m m n) A2) (four_block_mat B1 (Q1 * A0 * P2) (0m m n) B2)"
    unfolding similar_mat_def by auto
  thus ?thesis using wit by auto
qed

(* Similarity of block-diagonal matrices: if A1 ~ B1 and A2 ~ B2, then the
   block-diagonal matrix of A1, A2 is similar to the block-diagonal matrix of
   B1, B2 (both off-diagonal blocks are zero). *)
lemma similar_mat_four_block_0_0: assumes
      1: "similar_mat A1 B1"
  and 2: "similar_mat A2 B2"
  and A1: "A1  carrier_mat n n"
  and A2: "A2  carrier_mat m m"
  shows "similar_mat (four_block_mat A1 (0m n m) (0m m n) A2)
    (four_block_mat B1 (0m n m) (0m m n) B2)"
proof -
  from 1[unfolded similar_mat_def] obtain P1 Q1 where 1: "similar_mat_wit A1 B1 P1 Q1" by auto
  note w1 = similar_mat_witD2[OF A1 1]
  from 2[unfolded similar_mat_def] obtain P2 Q2 where 2: "similar_mat_wit A2 B2 P2 Q2" by auto
  note w2 = similar_mat_witD2[OF A2 2]
  from w1 w2 have C: "B1  carrier_mat n n" "B2  carrier_mat m m" by auto
  (* both zero off-diagonal blocks stay zero under conjugation *)
  from w1 w2 have id: "0m m n = Q2 * 0m m n * P1" by simp
  from w1 w2 have id2: "0m n m = Q1 * 0m n m * P2" by simp
  from similar_mat_wit_sym[OF similar_mat_wit_four_block[OF similar_mat_wit_sym[OF 1] similar_mat_wit_sym[OF 2]
    id2 id C zero_carrier_mat zero_carrier_mat]]
  show ?thesis unfolding similar_mat_def by blast
qed

(* If every pair (A,B) in the list Ms consists of similar matrices, then the
   block-diagonal matrix of the first components is similar to the
   block-diagonal matrix of the second components. Proof by induction on Ms,
   using similar_mat_four_block_0_0 in the step case. *)
lemma similar_diag_mat_block_mat: assumes " A B. (A,B)  set Ms  similar_mat A B"
  shows "similar_mat (diag_block_mat (map fst Ms)) (diag_block_mat (map snd Ms))"
  using assms
proof (induct Ms)
  case Nil
  show ?case by (auto intro!: similar_mat_refl[of _ 0])
next
  case (Cons AB Ms)
  obtain A B where AB: "AB = (A,B)" by force
  (* head pair: similar square matrices of some size n *)
  from Cons(2)[of A B] have simAB: "similar_mat A B" unfolding AB by auto
  from similar_matD[OF this] obtain n where A: "A  carrier_mat n n" and B: "B  carrier_mat n n" by auto
  hence [simp]: "dim_row A = n" "dim_col A = n" "dim_row B = n" "dim_col B = n" by auto
  (* tail: similar by the induction hypothesis, of some size m *)
  let ?C = "diag_block_mat (map fst Ms)" let ?D = "diag_block_mat (map snd Ms)"
  from Cons(1)[OF Cons(2)] have simRec: "similar_mat ?C ?D" by auto
  from similar_matD[OF this] obtain m where C: "?C  carrier_mat m m" and D: "?D  carrier_mat m m" by auto
  hence [simp]: "dim_row ?C = m" "dim_col ?C = m" "dim_row ?D = m" "dim_col ?D = m" by auto
  have "similar_mat (diag_block_mat (map fst (AB # Ms))) (diag_block_mat (map snd (AB # Ms)))
    = similar_mat (four_block_mat A (0m n m) (0m m n) ?C) (four_block_mat B (0m n m) (0m m n) ?D)"
    unfolding AB by (simp add: Let_def)
  also have ""
    by (rule similar_mat_four_block_0_0[OF simAB simRec A C])
  finally show ?case .
qed

(* Similarity is preserved by taking powers, with unchanged witnesses: if
   A ~ B via (P,Q) then A ^m k ~ B ^m k via (P,Q). The core of the proof is
   the identity A ^m k = P * B ^m k * Q, shown by induction on k. *)
lemma similar_mat_wit_pow: assumes wit: "similar_mat_wit A B P Q"
  shows "similar_mat_wit (A ^m k) (B ^m k) P Q"
proof -
  define n where "n = dim_row A"
  let ?C = "carrier_mat n n"
  from similar_mat_witD[OF refl wit, folded n_def] have
    A: "A  ?C" and B: "B  ?C" and P: "P  ?C" and Q: "Q  ?C"
    and PQ: "P * Q = 1m n" and QP: "Q * P = 1m n"
    and AB: "A = P * B * Q"
    by auto
  from A B have *: "(A ^m k)  carrier_mat n n" "B ^m k  carrier_mat n n" by auto
  note carr = A B P Q
  have id: "A ^m k = P * B ^m k * Q" unfolding AB
  proof (induct k)
    case 0
    thus ?case using carr by (simp add: PQ)
  next
    case (Suc k)
    (* Bk hides the power B ^m k behind a name while the product is regrouped *)
    define Bk where "Bk = B ^m k"
    have Bk: "Bk  carrier_mat n n" unfolding Bk_def using carr by simp
    have "(P * B * Q) ^m Suc k = (P * Bk * Q) * (P * B * Q)" by (simp add: Suc Bk_def)
    (* regroup so that the inner Q * P can be cancelled *)
    also have " = P * (Bk * (Q * P) * B) * Q"
      using carr Bk by (simp add: assoc_mult_mat[of _ n n _ n _ n])
    also have "Bk * (Q * P) = Bk" unfolding QP using Bk by simp
    finally show ?case unfolding Bk_def by simp
  qed
  show ?thesis
    by (rule similar_mat_witI[OF PQ QP id * P Q])
qed

(* Equational consequence of similar_mat_wit_pow: the k-th power of A is the
   conjugate of the k-th power of B by the same witnesses. *)
lemma similar_mat_wit_pow_id: "similar_mat_wit A B P Q  A ^m k = P * B ^m k * Q"
  using similar_mat_wit_pow[of A B P Q k] unfolding similar_mat_wit_def Let_def by blast

subsection‹Homomorphism properties›

(* Lifting of a semiring homomorphism hom to vectors and matrices by applying
   it entrywise, together with the basic compatibility lemmas (zero vector,
   identity matrix, matrix product, matrix-vector product). *)
context semiring_hom
begin
(* entrywise application of hom to a matrix *)
abbreviation mat_hom :: "'a mat  'b mat" ("math")
  where "math  map_mat hom"

(* entrywise application of hom to a vector *)
abbreviation vec_hom :: "'a vec  'b vec" ("vech")
  where "vech  map_vec hom"

(* the lifted homomorphism maps the zero vector to the zero vector *)
lemma vec_hom_zero: "vech (0v n) = 0v n"
  by (rule eq_vecI, auto)

(* the lifted homomorphism maps the identity matrix to the identity matrix *)
lemma mat_hom_one: "math (1m n) = 1m n"
  by (rule eq_matI, auto)

(* The lifted homomorphism commutes with matrix multiplication. Each entry is
   a scalar product, i.e. a finite sum of products; the sum is handled by
   induction over a finite index set I using hom_add and hom_mult. *)
lemma mat_hom_mult: assumes A: "A  carrier_mat nr n" and B: "B  carrier_mat n nc"
  shows "math (A * B) = math A * math B"
proof -
  let ?L = "math (A * B)"
  let ?R = "math A * math B"
  let ?A = "math A"
  let ?B = "math B"
  from A B have id:
    "dim_row ?L = nr" "dim_row ?R = nr"
    "dim_col ?L = nc" "dim_col ?R = nc"  by auto
  show ?thesis
  proof (rule eq_matI, unfold id)
    fix i j
    assume *: "i < nr" "j < nc"
    (* I names the summation range so that the induction below can generalise it *)
    define I where "I = {0 ..< n}"
    have id: "{0 ..< dim_vec (col ?B j)} = I" "{0 ..< dim_vec (col B j)} = I"
      unfolding I_def using * B by auto
    have finite: "finite I" unfolding I_def by auto
    have I: "I  {0 ..< n}" unfolding I_def by auto
    have "?L $$ (i,j) = hom (row A i  col B j)" using A B * by auto
    also have " = row ?A i  col ?B j" unfolding scalar_prod_def id using finite I
    proof (induct I)
      case (insert k I)
      show ?case unfolding sum.insert[OF insert(1-2)] hom_add hom_mult
        using insert(3-) * A B by auto
    qed simp
    also have " = ?R $$ (i,j)" using A B * by auto
    finally
    show "?L $$ (i, j) = ?R $$ (i, j)" .
  qed auto
qed

(* The lifted homomorphism commutes with matrix-vector multiplication; the
   proof follows the same pattern as mat_hom_mult. *)
lemma mult_mat_vec_hom: assumes A: "A  carrier_mat nr n" and v: "v  carrier_vec n"
  shows "vech (A *v v) = math A *v vech v"
proof -
  let ?L = "vech (A *v v)"
  let ?R = "math A *v vech v"
  let ?A = "math A"
  let ?v = "vech v"
  from A v have id:
    "dim_vec ?L = nr" "dim_vec ?R = nr"
    by auto
  show ?thesis
  proof (rule eq_vecI, unfold id)
    fix i
    assume *: "i < nr"
    define I where "I = {0 ..< n}"
    have id: "{0 ..< dim_vec v} = I" "{0 ..< dim_vec (vech v)} = I"
      unfolding I_def using * v  by auto
    have finite: "finite I" unfolding I_def by auto
    have I: "I  {0 ..< n}" unfolding I_def by auto
    have "?L $ i = hom (row A i  v)" using A v * by auto
    also have " = row ?A i  ?v" unfolding scalar_prod_def id using finite I
    proof (induct I)
      case (insert k I)
      show ?case unfolding sum.insert[OF insert(1-2)] hom_add hom_mult
        using insert(3-) * A v by auto
    qed simp
    also have " = ?R $ i" using A v * by auto
    finally
    show "?L $ i = ?R $ i" .
  qed auto
qed
end

(* Extensionality for vectors as an equivalence: two vectors are equal exactly
   when they have the same dimension and agree at every index below it. *)
lemma vec_eq_iff: "(x = y) = (dim_vec x = dim_vec y  ( i < dim_vec y. x $ i = y $ i))" (is "?l = ?r")
proof
  assume ?r
  show ?l
    by (rule eq_vecI, insert ?r, auto)
qed simp

(* Extensionality for matrices as an equivalence: two matrices are equal
   exactly when both dimensions coincide and all in-range entries agree. *)
lemma mat_eq_iff: "(x = y) = (dim_row x = dim_row y  dim_col x = dim_col y 
  ( i j. i < dim_row y  j < dim_col y  x $$ (i,j) = y $$ (i,j)))" (is "?l = ?r")
proof
  assume ?r
  show ?l
    by (rule eq_matI, insert ?r, auto)
qed simp

(* For an injective semiring homomorphism, the entrywise image of a vector is
   the zero vector of dimension n exactly when the vector itself is (simp rule). *)
lemma (in inj_semiring_hom) vec_hom_zero_iff[simp]: "(vech x = 0v n) = (x = 0v n)"
proof -
  {
    (* entrywise statement, used below after unfolding vector equality *)
    fix i
    assume i: "i < n" "dim_vec x = n"
    hence "vech x $ i = 0  x $ i = 0"
      using index_map_vec(1)[of i x] by simp
  } note main = this
  show ?thesis unfolding vec_eq_iff by (simp, insert main, auto)
qed

(* Entrywise application of an injective semiring homomorphism is injective on matrices. *)
lemma (in inj_semiring_hom) mat_hom_inj: "math A = math B  A = B"
  unfolding mat_eq_iff by auto

(* Entrywise application of an injective semiring homomorphism is injective on vectors. *)
lemma (in inj_semiring_hom) vec_hom_inj: "vech v = vech w  v = w"
  unfolding vec_eq_iff by auto

(* Entrywise application of a semiring homomorphism commutes with powers of a
   square matrix. Induction on k, using mat_hom_mult in the step and
   mat_hom_one in the base case. *)
lemma (in semiring_hom) mat_hom_pow: assumes A: "A  carrier_mat n n"
  shows "math (A ^m k) = (math A) ^m k"
proof (induct k)
  case (Suc k)
  thus ?case using mat_hom_mult[OF pow_carrier_mat[OF A, of k] A] by simp
qed (simp add: mat_hom_one)

(* A semiring homomorphism commutes with sum_mat: applying hom to the sum of
   all entries equals summing the entries of the entrywise image. *)
lemma (in semiring_hom) hom_sum_mat: "hom (sum_mat A) = sum_mat (math A)"
proof -
  (* restate the goal as a sum over an index set B inside the matrix's index
     range, so that B can be generalised in the induction below *)
  obtain B where id: "?thesis = (hom (sum (($$) A) B) = sum (($$) (math A)) B)"
    and B: "B  {0..<dim_row A} × {0..<dim_col A}"
  unfolding sum_mat_def by auto
  from B have "finite B"
    using finite_subset by blast
  thus ?thesis unfolding id using B
  proof (induct B)
    case (insert x F)
    show ?case unfolding sum.insert[OF insert(1-2)] hom_add
      using insert(3-) by auto
  qed simp
qed

(* Entrywise application of a semiring homomorphism commutes with scalar
   multiplication of a vector: the scalar is mapped by hom as well. *)
lemma (in semiring_hom) vec_hom_smult: "vech (ev v v) = hom ev v vech v"
  by (rule eq_vecI, auto simp: hom_distribs)

(* The scalar product distributes over subtraction in its first argument, for
   vectors of equal dimension over a ring. Subtraction is first rewritten as
   addition of the negated vector. *)
lemma minus_scalar_prod_distrib: fixes v1 :: "'a :: ring vec"
  assumes v: "v1  carrier_vec n" "v2  carrier_vec n" "v3  carrier_vec n"
  shows "(v1 - v2)  v3 = v1  v3 - v2  v3"
  unfolding minus_add_uminus_vec[OF v(1-2)]
  by (subst add_scalar_prod_distrib[OF v(1)], insert v, auto)

(* The scalar product distributes over subtraction in its second argument, for
   vectors of equal dimension over a ring. *)
lemma scalar_prod_minus_distrib: fixes v1 :: "'a :: ring vec"
  assumes v: "v1  carrier_vec n" "v2  carrier_vec n" "v3  carrier_vec n"
  shows "v1  (v2 - v3) = v1  v2 - v1  v3"
  unfolding minus_add_uminus_vec[OF v(2-3)]
  by (subst scalar_prod_add_distrib[OF v(1)], insert v, auto)

(* Negating a sum of two vectors of equal dimension: - (l + r) = - l - r. *)
lemma uminus_add_minus_vec:
  assumes "l  carrier_vec n" "r  carrier_vec n"
  shows "- ((l::'a :: ab_group_add vec) + r) = (- l - r)"
  using assms by auto

(* Subtracting a sum of vectors equals subtracting the summands one after the
   other, for vectors of equal dimension. *)
lemma minus_add_minus_vec: fixes u :: "'a :: ab_group_add vec"
  assumes "u  carrier_vec n" "v  carrier_vec n" "w  carrier_vec n"
  shows "u - (v + w) = u - v - w"
  using assms by auto

(* Negating a sum of two matrices of equal dimensions: - (l + r) = - l - r. *)
lemma uminus_add_minus_mat:
  assumes "l  carrier_mat nr nc" "r  carrier_mat nr nc"
  shows "- ((l::'a :: ab_group_add mat) + r) = (- l - r)"
  using assms by auto

(* Subtracting a sum of matrices equals subtracting the summands one after the
   other, for matrices of equal dimensions. *)
lemma minus_add_minus_mat: fixes u :: "'a :: ab_group_add mat"
  assumes "u  carrier_mat nr nc" "v  carrier_mat nr nc" "w  carrier_mat nr nc"
  shows "u - (v + w) = u - v - w"
  using assms by auto

(* Double negation of a vector cancels (simp rule). *)
lemma uminus_uminus_vec[simp]: "- (- (v::'a:: group_add vec)) = v"
  by auto

(* Vector negation is injective: - v = - w holds exactly when v = w (simp rule). *)
lemma uminus_eq_vec[simp]: "- (v::'a:: group_add vec) = - w  v = w"
  by (metis uminus_uminus_vec)

(* Double negation of a matrix cancels (simp rule). *)
lemma uminus_uminus_mat[simp]: "- (- (A::'a:: group_add mat)) = A"
  by auto

(* Matrix negation is injective: - A = - B holds exactly when A = B (simp rule). *)
lemma uminus_eq_mat[simp]: "- (A::'a:: group_add mat) = - B  A = B"
  by (metis uminus_uminus_mat)

(* Scaling a zero matrix by any scalar yields the zero matrix of the same
   dimensions (simp rule). *)
lemma smult_zero_mat[simp]: "(k :: 'a :: mult_zero) m 0m nr nc = 0m nr nc"
  by (intro eq_matI, auto)

(* Similarity with witnesses is preserved by scaling both matrices with the
   same scalar k, over a commutative ring; the witnesses P and Q are unchanged. *)
lemma similar_mat_wit_smult: fixes A :: "'a :: comm_ring_1 mat"
  assumes "similar_mat_wit A B P Q"
  shows "similar_mat_wit (k m A) (k m B) P Q"
proof -
  define n where "n = dim_row A"
  (* main(1-2): P, Q mutually inverse; main(3): A = P * B * Q; main(4-7): carrier facts *)
  note main = similar_mat_witD[OF n_def assms]
  show ?thesis
    by (rule similar_mat_witI[OF main(1-2) _ _ _ main(6-7)], insert main(3-), auto
      simp: mult_smult_distrib mult_smult_assoc_mat[of _ n n _ n])
qed

(* Similarity is preserved by scaling both matrices with the same scalar,
   derived from similar_mat_wit_smult. *)
lemma similar_mat_smult: fixes A :: "'a :: comm_ring_1 mat"
  assumes "similar_mat A B"
  shows "similar_mat (k m A) (k m B)"
  using similar_mat_wit_smult assms unfolding similar_mat_def by blast

(* mat_diag n f: the n-by-n matrix whose diagonal entry at position (j,j) is
   f j and whose other entries are 0. *)
definition mat_diag :: "nat  (nat  'a :: zero)  'a mat" where
  "mat_diag n f = Matrix.mat n n (λ (i,j). if i = j then f j else 0)"

(* mat_diag n f is an n-by-n matrix (simp rule). *)
lemma mat_diag_dim[simp]: "mat_diag n f  carrier_mat n n"
  unfolding mat_diag_def by auto

(* Multiplying with mat_diag n f from the left scales row i of A by f i. *)
lemma mat_diag_mult_left: assumes A: "A  carrier_mat n nr"
  shows "mat_diag n f * A = Matrix.mat n nr (λ (i,j). f i * A $$ (i,j))"
proof (rule eq_matI, insert A, auto simp: mat_diag_def scalar_prod_def, goal_cases)
  case (1 i j)
  (* pull the summand with index i out of the scalar-product sum *)
  thus ?case by (subst sum.remove[of _ i], auto)
qed

(* Multiplying with mat_diag n f from the right scales column j of A by f j. *)
lemma mat_diag_mult_right: assumes A: "A  carrier_mat nr n"
  shows "A * mat_diag n f = Matrix.mat nr n (λ (i,j). A $$ (i,j) * f j)"
proof (rule eq_matI, insert A, auto simp: mat_diag_def scalar_prod_def, goal_cases)
  case (1 i j)
  (* pull the summand with index j out of the scalar-product sum *)
  thus ?case by (subst sum.remove[of _ j], auto)
qed

(* The product of two mat_diag matrices of the same size is the mat_diag of the
   pointwise product of their diagonal functions (simp rule). *)
lemma mat_diag_diag[simp]: "mat_diag n f * mat_diag n g = mat_diag n (λ i. f i * g i)"
  by (subst mat_diag_mult_left[of _ n n], auto simp: mat_diag_def)

(* mat_diag with the constant function 1 is the identity matrix (simp rule). *)
lemma mat_diag_one[simp]: "mat_diag n (λ x. 1) = 1m n" unfolding mat_diag_def by auto

text ‹Interpret a vector as a row matrix, i.e., a matrix with exactly one row.›

(* mat_of_row y: the matrix with one row and dim_vec y columns whose entry in
   column j is y $ j (the row index of the position is ignored). *)
definition "mat_of_row y = mat 1 (dim_vec y) (λ ij. y $ (snd ij))" 

(* Carrier facts for mat_of_row, stated once with row count 1 and once with
   Suc 0 so that either spelling of the number one is matched. *)
lemma mat_of_row_carrier[simp,intro]: 
  "y  carrier_vec n  mat_of_row y  carrier_mat 1 n"
  "y  carrier_vec n  mat_of_row y  carrier_mat (Suc 0) n"
  unfolding mat_of_row_def by auto

(* Dimensions of mat_of_row y: one row, and as many columns as y has entries (simp rules). *)
lemma mat_of_row_dim[simp]: "dim_row (mat_of_row y) = 1" 
  "dim_col (mat_of_row y) = dim_vec y" 
  unfolding mat_of_row_def by auto

(* Entry x of the single row of mat_of_row y is entry x of y (simp rule). *)
lemma mat_of_row_index[simp]: "x < dim_vec y  mat_of_row y $$ (0,x) = y $ x" 
  unfolding mat_of_row_def by auto

(* Extracting row 0 of mat_of_row y gives back y (simp rule). *)
lemma row_mat_of_row[simp]: "row (mat_of_row y) 0 = y" 
  by auto

(* Multiplying the row matrix of an appended vector y1 @v y2 with the matrix
   obtained by stacking A1 on top of A2 equals the sum of the two separate
   products mat_of_row y1 * A1 and mat_of_row y2 * A2.
   The row matrix is rewritten as a four_block_mat whose lower blocks have zero
   rows, so that block-wise multiplication (mult_four_block_mat) applies. *)
lemma mat_of_row_mult_append_rows: assumes y1: "y1  carrier_vec nr1" 
  and y2: "y2  carrier_vec nr2" 
  and A1: "A1  carrier_mat nr1 nc" 
  and A2: "A2  carrier_mat nr2 nc" 
shows "mat_of_row (y1 @v y2) * (A1 @r A2) = 
  mat_of_row y1 * A1 + mat_of_row y2 * A2" 
proof -
  from A1 A2 have dim: "dim_row A1 = nr1" "dim_row A2 = nr2" by auto
  let ?M1 = "mat_of_row y1" 
  have M1: "?M1  carrier_mat 1 nr1" using y1 by auto
  let ?M2 = "mat_of_row y2" 
  have M2: "?M2  carrier_mat 1 nr2" using y2 by auto
  (* lower blocks with zero rows: they only serve to fit the four_block_mat shape *)
  let ?M3 = "0m 0 nr1" 
  let ?M4 = "0m 0 nr2" 
  note z = zero_carrier_mat
  (* The simp rules for mat_of_row suffice here; the former extra simp rule
     mat_of_rows_def concerned a different constant (mat_of_rows) that does not
     occur in this goal, so it has been dropped. *)
  have id: "mat_of_row (y1 @v y2) = four_block_mat 
    ?M1 ?M2 ?M3 ?M4" using y1 y2 
    by (intro eq_matI, auto)
  show ?thesis
    unfolding id append_rows_def dim
    by (subst mult_four_block_mat[OF M1 M2 z z A1 z A2 z], insert A1 A2, auto)
qed


text ‹Constructors that allow vectors to be built and taken apart like lists.›
(* vNil: the vector of dimension 0. *)
abbreviation vNil where "vNil  vec 0 ((!) [])"
(* vCons a v: the vector of dimension Suc (dim_vec v) with a at index 0 and
   v $ i at index Suc i. *)
definition vCons where "vCons a v  vec (Suc (dim_vec v)) (λi. case i of 0  a | Suc i  v $ i)"

(* The entry at index 0 of vCons a v is a (simp rule). *)
lemma vec_index_vCons_0 [simp]: "vCons a v $ 0 = a"
  by (simp add: vCons_def)

(* The entry at index Suc n of vCons a v is the entry at index n of v (simp
   rule). The statement carries no bound on n, so the proof works on the
   representation of vectors via transfer and mk_vec_def. *)
lemma vec_index_vCons_Suc [simp]:
  fixes v :: "'a vec"
  shows "vCons a v $ Suc n = v $ n"
proof-
  (* auxiliary fact: indexing vec (Suc d) f at Suc n is indexing vec d (f o Suc) at n *)
  have 1: "vec (Suc d) f $ Suc n = vec d (f  Suc) $ n" for d and f :: "nat  'a"
    by (transfer, auto simp: mk_vec_def)
  show ?thesis
    apply (auto simp: 1 vCons_def o_def) apply transfer apply (auto simp: mk_vec_def)
    done
qed

(* Combined index rule for vCons: index 0 yields the head a, any other index n
   yields v $ (n - 1). *)
lemma vec_index_vCons: "vCons a v $ n = (if n = 0 then a else v $ (n - 1))"
  by (cases n, auto)

(* vCons increases the dimension by one (simp rule). *)
lemma dim_vec_vCons [simp]: "dim_vec (vCons a v) = Suc (dim_vec v)"
  by (simp add: vCons_def)

(* vCons a v has dimension Suc n exactly when v has dimension n (simp rule). *)
lemma vCons_carrier_vec[simp]: "vCons a v  carrier_vec (Suc n)  v  carrier_vec n"
  by (auto dest!: carrier_vecD intro: carrier_vecI)

(* A vector of dimension Suc n given by f decomposes into its head f 0 and the
   vector of dimension n given by the shifted function f o Suc. *)
lemma vec_Suc: "vec (Suc n) f = vCons (f 0) (vec n (f  Suc))" (is "?l = ?r")
proof (unfold vec_eq_iff, intro conjI allI impI)
  fix i assume "i < dim_vec ?r"
  then show "?l $ i = ?r $ i" by (cases i, auto)
qed simp

(* Unregister Abs_vec_cases as a cases rule; the lemma vec_cases that follows is
   declared with "cases type: vec" (presumably to take its place as the default
   case analysis for vectors). *)
declare Abs_vec_cases[cases del]

(* Case analysis for vectors in list style: every vector is either vNil or of
   the form vCons a w. Registered as cases rule for type vec with the case
   names vNil and vCons. *)
lemma vec_cases [case_names vNil vCons, cases type: vec]:
  assumes "v = vNil  thesis" and "a w. v = vCons a w  thesis"
  shows "thesis"
proof (cases "dim_vec v")
  case 0 then show thesis by (intro assms(1), auto)
next
  case (Suc n)
  show thesis
  proof (rule assms(2))
    (* explicit decomposition: head v $ 0, tail built from the entries v $ Suc i *)
    show v: "v = vCons (v $ 0) (vec n (λi. v $ Suc i))" (is "v = ?r")
    proof (rule eq_vecI, unfold dim_vec_vCons dim_vec Suc)
      fix i
      assume "i < Suc n"
      then show "v $ i = ?r $ i" by (cases i, auto simp: vCons_def)
    qed simp
  qed
qed

(* Structural induction for vectors in list style: to prove P for all vectors,
   prove it for vNil and show that it is preserved by vCons. Registered as
   induction rule for type vec. The proof is by induction on the dimension. *)
lemma vec_induct [case_names vNil vCons, induct type: vec]:
  assumes "P vNil" and "a v. P v  P (vCons a v)"
  shows "P v"
proof (induct "dim_vec v" arbitrary:v)
  case 0 then show ?case by (cases v, auto intro: assms(1))
next
  case (Suc n) then show ?case by (cases v, auto intro: assms(2))
qed

(* Induction rule for membership in carrier_vec n: the predicate P takes the
   dimension as an additional argument. Base case: dimension 0 with vNil; step
   case: from a vector v of dimension n to vCons a v of dimension Suc n. *)
lemma carrier_vec_induct [consumes 1, case_names 0 Suc, induct set:carrier_vec]:
  assumes v: "v  carrier_vec n"
    and 1: "P 0 vNil" and 2: "n a v. v  carrier_vec n  P n v  P (Suc n) (vCons a v)"
  shows "P n v"
proof (insert v, induct n arbitrary: v)
  case 0 then have "v = vec 0 ((!) [])" by auto
  with 1 show ?case by auto
next
  case (Suc n) then show ?case by (cases v, auto dest!: carrier_vecD intro:2)
qed

(* vec_of_list maps list cons to vCons (simp rule). *)
lemma vec_of_list_Cons[simp]: "vec_of_list (a#as) = vCons a (vec_of_list as)"
  by (unfold vCons_def, transfer, auto simp:mk_vec_def split:nat.split)

(* vec_of_list maps the empty list to vNil (simp rule). *)
lemma vec_of_list_Nil[simp]: "vec_of_list [] = vNil"
  by (transfer', auto)

(* Scalar product of two vCons vectors: the product of the heads plus the
   scalar product of the tails (simp rule). The proof splits index 0 off the
   sum and re-indexes the remaining sum. *)
lemma scalar_prod_vCons[simp]:
  "vCons a v  vCons b w = a * b + v  w"
  apply (unfold scalar_prod_def atLeast0_lessThan_Suc_eq_insert_0 dim_vec_vCons)
  apply (subst sum.insert) apply (simp,simp)
  apply (subst sum.reindex) apply force
  apply simp
  done

(* The zero vector of dimension Suc n is 0 consed onto the zero vector of dimension n. *)
lemma zero_vec_Suc: "0v (Suc n) = vCons 0 (0v n)"
  by (auto simp: zero_vec_def vec_Suc o_def)

(* The zero vector of dimension 0 is vNil (simp rule). *)
lemma zero_vec_zero[simp]: "0v 0 = vNil" by auto

(* Injectivity of vCons: two vCons vectors are equal exactly when their heads
   and their tails are equal (simp rule). *)
lemma vCons_eq_vCons[simp]: "vCons a v = vCons b w  a = b  v = w" (is "?l  ?r")
proof
  assume ?l
  note arg_cong[OF this]
  (* apply dim_vec, indexing at 0, and indexing at Suc _ to both sides of the equation *)
  from this[of dim_vec] this[of "λx. x$0"] this[of "λx. x$Suc _"]
  show ?r by (auto simp: vec_eq_iff)
qed simp

(* Simp rule relating membership of vec n f in carrier_vec m to n = m. *)
lemma vec_carrier_vec[simp]: "vec n f  carrier_vec m  n = m"
  unfolding carrier_vec_def by auto

(* Introduce notation for transpose_mat with priority 1000. *)
notation transpose_mat ("(_T)" [1000])

(* map_mat commutes with transposition. *)
lemma map_mat_transpose: "(map_mat f A)T = map_mat f AT" by auto

(* The columns of the transpose are the rows of the original matrix. *)
lemma cols_transpose[simp]: "cols AT = rows A" unfolding cols_def rows_def by auto
(* The rows of the transpose are the columns of the original matrix. *)
lemma rows_transpose[simp]: "rows AT = cols A" unfolding cols_def rows_def by auto
(* list_of_vec of vec n f is f mapped over the index list [0..<n]. *)
lemma list_of_vec_vec [simp]: "list_of_vec (vec n f) = map f [0..<n]"
  by (transfer, auto simp: mk_vec_def)

(* The zero vector of dimension n corresponds to the list replicate n 0. *)
lemma list_of_vec_0 [simp]: "list_of_vec (0v n) = replicate n 0"
  by (simp add: zero_vec_def map_replicate_trivial)

(* For M in carrier_mat n n, the diagonal of map_mat f M is f mapped over the
   diagonal of M. Fact m gives the entrywise statement on the diagonal; the
   lists are then compared with nth_equalityI. *)
lemma diag_mat_map:
  assumes M_carrier: "M  carrier_mat n n"
  shows "diag_mat (map_mat f M) = map f (diag_mat M)"
proof -
  have dim_eq: "dim_row M = dim_col M" using M_carrier by auto
  have m: "map_mat f M $$ (i, i) = f (M $$ (i, i))" if i: "i < dim_row M" for i
    using dim_eq i by auto
  show ?thesis
    by (rule nth_equalityI, insert m, auto simp add: diag_mat_def M_carrier)
qed

(* For row vectors vs all taken from carrier_vec n, building the matrix from
   the mapped rows equals mapping f over the matrix built from vs. *)
lemma mat_of_rows_map [simp]:
  assumes x: "set vs  carrier_vec n"
  shows "mat_of_rows n (map (map_vec f) vs) = map_mat f (mat_of_rows n vs)"
proof-
  have "xset vs. dim_vec x = n" using x by auto
  then show ?thesis by (auto simp add: mat_eq_iff map_vec_def mat_of_rows_def)
qed

(* Column analogue of mat_of_rows_map: for column vectors vs taken from
   carrier_vec n, mat_of_cols commutes with mapping f. *)
lemma mat_of_cols_map [simp]:
  assumes x: "set vs  carrier_vec n"
  shows "mat_of_cols n (map (map_vec f) vs) = map_mat f (mat_of_cols n vs)"
proof-
  have "xset vs. dim_vec x = n" using x by auto
  then show ?thesis by (auto simp add: mat_eq_iff map_vec_def mat_of_cols_def)
qed

(* vec_of_list commutes with mapping: list map becomes map_vec. *)
lemma vec_of_list_map [simp]: "vec_of_list (map f xs) = map_vec f (vec_of_list xs)"
  unfolding map_vec_def by (transfer, auto simp add: mk_vec_def)

(* map_vec f on vec n g composes f with the generating function g.
   Not declared as a simp rule. *)
lemma map_vec: "map_vec f (vec n g) = vec n (f o g)" by auto

(* For i < n, entry (i, 0) of mat_of_cols n (w # ws) is entry i of the
   first column vector w. *)
lemma mat_of_cols_Cons_index_0: "i < n  mat_of_cols n (w # ws) $$ (i, 0) = w $ i"
  by (unfold mat_of_cols_def, transfer', auto simp: mk_mat_def)

(* Indexing map f xs at an index i outside the list reduces to indexing the
   empty list at i - length xs, so the result does not depend on f.
   Used in the proof of mat_of_cols_Cons_index_Suc. *)
lemma nth_map_out_of_bound: "i  length xs  map f xs ! i = [] ! (i - length xs)"
  by (induct xs arbitrary:i, auto)

(* For i < n, column Suc j of mat_of_cols n (w # ws) is column j of
   mat_of_cols n ws. No bound on j is assumed; the proof uses undef_mat_def
   and nth_map_out_of_bound to cover column indices past the end. *)
lemma mat_of_cols_Cons_index_Suc:
  "i < n  mat_of_cols n (w # ws) $$ (i, Suc j) = mat_of_cols n ws $$ (i,j)"
  by (unfold mat_of_cols_def, transfer, auto simp: mk_mat_def undef_mat_def nth_append nth_map_out_of_bound)

(* In-range entry (i,j) of mat_of_cols n ws is entry i of the j-th vector. *)
lemma mat_of_cols_index: "i < n  j < length ws  mat_of_cols n ws $$ (i,j) = ws ! j $ i"
  by (unfold mat_of_cols_def, auto)

(* In-range entry (i,j) of mat_of_rows n rs is entry j of the i-th vector. *)
lemma mat_of_rows_index: "i < length rs  j < n  mat_of_rows n rs $$ (i,j) = rs ! i $ j"
  by (unfold mat_of_rows_def, auto)

(* Transposing a matrix built from rows gives the matrix built from the same
   vectors as columns. *)
lemma transpose_mat_of_rows: "(mat_of_rows n vs)T = mat_of_cols n vs"
  by (auto intro!: eq_matI simp: mat_of_rows_index mat_of_cols_index)

(* Transposing a matrix built from columns gives the matrix built from the
   same vectors as rows. *)
lemma transpose_mat_of_cols: "(mat_of_cols n vs)T = mat_of_rows n vs"
  by (auto intro!: eq_matI simp: mat_of_rows_index mat_of_cols_index)

(* For an in-range index, list indexing on list_of_vec v agrees with vector
   indexing on v (simp rule). *)
lemma nth_list_of_vec [simp]:
  assumes "i < dim_vec v" shows "list_of_vec v ! i = v $ i"
  using assms by (transfer, auto)

(* The length of list_of_vec v is the dimension of v (simp rule). *)
lemma length_list_of_vec [simp]:
  "length (list_of_vec v) = dim_vec v" by (transfer, auto)

(* Characterises v being the zero vector of dimension n in terms of
   n = dim_vec v together with a condition on the entries of list_of_vec v
   (n = 0, or the set of entries is {0}). Both directions are shown
   separately. *)
lemma vec_eq_0_iff:
  "v = 0v n  n = dim_vec v  (n = 0  set (list_of_vec v) = {0})" (is "?l  ?r")
proof
  show "?l  ?r" by auto
  show "?r  ?l" by (intro iffI eq_vecI, force simp: set_conv_nth, force)
qed

(* list_of_vec turns vCons into list cons; shown elementwise via
   nth_equalityI with a case split on the index. *)
lemma list_of_vec_vCons[simp]: "list_of_vec (vCons a v) = a # list_of_vec v" (is "?l = ?r")
proof (intro nth_equalityI)
  fix i
  assume "i < length ?l"
  then show "?l ! i = ?r ! i" by (cases i, auto)
qed simp

(* Appending to a vCons vector: the head can be pulled outside the append.
   Shown entrywise via vec_eq_iff with a case split on the index. *)
lemma append_vec_vCons[simp]: "vCons a v @v w = vCons a (v @v w)" (is "?l = ?r")
proof (unfold vec_eq_iff, intro conjI allI impI)
  fix i assume "i < dim_vec ?r"
  then show "?l $ i = ?r $ i" by (cases i; subst index_append_vec, auto)
qed simp

(* vNil is a left identity for vector append (simp rule). *)
lemma append_vec_vNil[simp]: "vNil @v v = v"
  by (unfold vec_eq_iff, auto)

(* list_of_vec maps vector append to list append; by induction on v. *)
lemma list_of_vec_append[simp]: "list_of_vec (v @v w) = list_of_vec v @ list_of_vec w"
  by (induct v, auto)

(* Relates equality of two transposes to equality of the matrices themselves
   (simp rule); follows from transpose_transpose. *)
lemma transpose_mat_eq[simp]: "AT = BT  A = B"
  using transpose_transpose by metis

(* Introduction rule for matrix equality from column-wise equality plus equal
   row and column dimensions. Proved by rewriting the goal with
   transpose_mat_eq and applying eq_rowI. *)
lemma mat_col_eqI: assumes cols: " i. i < dim_col B  col A i = col B i"
  and dims: "dim_row A = dim_row B" "dim_col A = dim_col B"
shows "A = B"
  by(subst transpose_mat_eq[symmetric], rule eq_rowI,insert assms,auto)

(* The rows of an n x n upper-triangular matrix are pairwise distinct under
   the diagonal assumption diag. The inner block shows that two equal rows
   i < j < n would force the diagonal entry at (i,i) to be 0, from which
   False is derived using diag. *)
lemma upper_triangular_imp_distinct:
  assumes A: "A  carrier_mat n n"
    and tri: "upper_triangular A"
    and diag: "0  set (diag_mat A)"
  shows "distinct (rows A)"
proof-
  { fix i and j
    assume eq: "rows A ! i = rows A ! j" and ij: "i < j" and jn: "j < n"
    from tri A ij jn have "rows A ! j $ i = 0" by (auto dest!:upper_triangularD)
    with eq have "rows A ! i $ i = 0" by auto
    with diag ij jn A have False by (auto simp: diag_mat_def)
  }
  with A show ?thesis by (force simp: distinct_conv_nth nat_neq_iff)
qed

(* The dimension of vec_of_list as is the length of the list (simp rule). *)
lemma dim_vec_of_list[simp] :"dim_vec (vec_of_list as) = length as" by transfer auto

(* Round trip list -> vector -> list is the identity. *)
lemma list_vec: "list_of_vec (vec_of_list xs) = xs"
by (transfer, metis (mono_tags, lifting) atLeastLessThan_iff map_eq_conv map_nth mk_vec_def old.prod.case set_upt)

(* Round trip vector -> list -> vector is the identity. *)
lemma vec_list: "vec_of_list (list_of_vec v) = v"
apply transfer unfolding mk_vec_def by auto

(* For an in-range index, indexing vec_of_list xs agrees with list indexing.
   See vec_of_list_index for the variant without the bound. *)
lemma index_vec_of_list: "i<length xs  (vec_of_list xs) $ i = xs ! i"
by (metis vec.abs_eq index_vec vec_of_list.abs_eq)

(* Indexing vec_of_list xs agrees with list indexing for every index j; no
   bound on j is assumed. The proof unfolds undef_vec_def to handle indices
   past the end. *)
lemma vec_of_list_index: "vec_of_list xs $ j = xs ! j"
  apply transfer unfolding mk_vec_def unfolding undef_vec_def
  by (simp, metis append_Nil2 nth_append)

(* Unconditional version of nth_list_of_vec: no bound on j is assumed.
   Derived from vec_list and vec_of_list_index. *)
lemma list_of_vec_index: "list_of_vec v ! j = v $ j"
  by (metis vec_list vec_of_list_index)

(* list_of_vec expressed as indexing mapped over [0..<dim_vec xs]. *)
lemma list_of_vec_map: "list_of_vec xs = map (($) xs) [0..<dim_vec xs]" by transfer auto

(* Entrywise product of two vectors; the result has the smaller of the two
   dimensions. *)
definition "component_mult v w = vec (min (dim_vec v) (dim_vec w)) (λi. v $ i * w $ i)"
(* The set of entries of a vector: the image of vec_index v over the index
   set {..<dim_vec v}. Comes with its own notation. *)
definition vec_set::"'a vec  'a set" ("setv")
  where "vec_set v = vec_index v ` {..<dim_vec v}"

(* For an index valid in both vectors, component_mult yields the product of
   the two entries. *)
lemma index_component_mult:
assumes "i < dim_vec v" "i < dim_vec w"
shows "component_mult v w $ i = v $ i * w $ i"
  unfolding component_mult_def using assms index_vec by auto

(* The dimension of component_mult v w is the minimum of both dimensions. *)
lemma dim_component_mult:
"dim_vec (component_mult v w) = min (dim_vec v) (dim_vec w)"
  unfolding component_mult_def using index_vec by auto

(* Elimination rule for vec_set: an element of the entry set of v occurs at
   some in-range index i. *)
lemma vec_setE:
assumes "a  setv v"
obtains i where "v$i = a" "i<dim_vec v" using assms unfolding vec_set_def by blast

(* Introduction rule for vec_set: an entry at an in-range index belongs to
   the entry set of v. *)
lemma vec_setI:
assumes "v$i = a" "i<dim_vec v"
shows "a  setv v" using assms unfolding vec_set_def using image_eqI lessThan_iff by blast

(* The list-level element set of list_of_vec v coincides with vec_set v. *)
lemma set_list_of_vec: "set (list_of_vec v) = setv v" unfolding vec_set_def by transfer auto


(* Instance of class conjugate for vectors over a conjugate type;
   conjugation is defined entrywise. *)
instantiation vec :: (conjugate) conjugate
begin

(* Entrywise conjugation; the dimension is kept. *)
definition conjugate_vec :: "'a :: conjugate vec  'a vec"
  where "conjugate v = vec (dim_vec v) (λi. conjugate (v $ i))"

(* Conjugation distributes over vCons. *)
lemma conjugate_vCons [simp]:
  "conjugate (vCons a v) = vCons (conjugate a) (conjugate v)"
  by (auto simp: vec_Suc conjugate_vec_def)

(* Conjugation preserves the dimension. *)
lemma dim_vec_conjugate[simp]: "dim_vec (conjugate v) = dim_vec v"
  unfolding conjugate_vec_def by auto

(* Carrier membership of conjugate v in terms of carrier membership of v. *)
lemma carrier_vec_conjugate[simp]: "v  carrier_vec n  conjugate v  carrier_vec n"
  by (auto intro!: carrier_vecI)

(* In-range entries of conjugate v are the conjugated entries of v. *)
lemma vec_index_conjugate[simp]:
  shows "i < dim_vec v  conjugate v $ i = conjugate (v $ i)"
  unfolding conjugate_vec_def by auto

(* Class obligations: conjugation is an involution (by induction on v), and
   the statement relating conjugate v = conjugate w to v = w. For the latter,
   the dimensions are compared first and then the entries. *)
instance
proof
  fix v w :: "'a vec"
  show "conjugate (conjugate v) = v" by (induct v, auto simp: conjugate_vec_def)
  let ?v = "conjugate v"
  let ?w = "conjugate w"
  show "conjugate v = conjugate w  v = w"
  proof(rule iffI)
    assume cvw: "?v = ?w" show "v = w"
    proof(rule)
      have "dim_vec ?v = dim_vec ?w" using cvw by auto
      then show dim: "dim_vec v = dim_vec w" by simp
      fix i assume i: "i < dim_vec w"
      then have "conjugate v $ i = conjugate w $ i" using cvw by auto
      then have "conjugate (v$i) = conjugate (w $ i)" using i dim by auto
      then show "v $ i = w $ i" by auto
    qed
  qed auto
qed

end

(* For two vectors in carrier_vec n over a conjugatable_ring, conjugation
   distributes over vector addition. *)
lemma conjugate_add_vec:
  fixes v w :: "'a :: conjugatable_ring vec"
  assumes dim: "v : carrier_vec n" "w : carrier_vec n"
  shows "conjugate (v + w) = conjugate v + conjugate w"
  by (rule, insert dim, auto simp: conjugate_dist_add)

(* Negation commutes with conjugation of vectors. Note that w is fixed but
   does not occur in the statement. *)
lemma uminus_conjugate_vec:
  fixes v w :: "'a :: conjugatable_ring vec"
  shows "- (conjugate v) = conjugate (- v)"
  by (rule, auto simp:conjugate_neg)

(* The conjugate of the zero vector is the zero vector (simp rule). *)
lemma conjugate_zero_vec[simp]:
  "conjugate (0v n :: 'a :: conjugatable_ring vec) = 0v n" by auto

(* Conjugation leaves a vector of dimension 0 unchanged (simp rule). *)
lemma conjugate_vec_0[simp]:
  "conjugate (vec 0 f) = vec 0 f" by auto

(* The scalar product with a vector of dimension 0 is 0 (simp rule);
   proved by unfolding scalar_prod_def. *)
lemma sprod_vec_0[simp]: "v  vec 0 f = 0"
  by(auto simp: scalar_prod_def)

(* Relates conjugate v being the zero vector to v being the zero vector
   (simp rule); derived from conjugate_cancel_iff instantiated with the zero
   vector. *)
lemma conjugate_zero_iff_vec[simp]:
  fixes v :: "'a :: conjugatable_ring vec"
  shows "conjugate v = 0v n  v = 0v n"
  using conjugate_cancel_iff[of _ "0v n :: 'a vec"] by auto

(* Conjugation distributes over multiplication of a vector by a scalar k;
   shown entrywise using conjugate_dist_mul. *)
lemma conjugate_smult_vec:
  fixes k :: "'a :: conjugatable_ring"
  shows "conjugate (k v v) = conjugate k v conjugate v"
  using conjugate_dist_mul by (intro eq_vecI, auto)

(* For v and w in carrier_vec n, the conjugate of the scalar product is the
   scalar product of the conjugates. Proved by carrier_vec_induct on w with
   v arbitrary; each case splits v with the vector cases rule. *)
lemma conjugate_sprod_vec:
  fixes v w :: "'a :: conjugatable_ring vec"
  assumes v: "v : carrier_vec n" and w: "w : carrier_vec n"
  shows "conjugate (v  w) = conjugate v  conjugate w"
proof (insert w v, induct w arbitrary: v rule:carrier_vec_induct)
  case 0 then show ?case by (cases v, auto)
next
  case (Suc n b w) then show ?case
    by (cases v, auto dest: carrier_vecD simp:conjugate_dist_add conjugate_dist_mul)
qed 

(* Abbreviation (infix, priority 70) for the scalar product of v with the
   conjugate of w. *)
abbreviation cscalar_prod :: "'a vec  'a vec  'a :: conjugatable_ring" (infix "∙c" 70)
  where "(∙c)  λv w. v  conjugate w"

(* Conjugating the scalar product of conjugate v and w yields v ∙c w, for v
   and w in carrier_vec n; obtained by rewriting with conjugate_sprod_vec. *)
lemma conjugate_conjugate_sprod[simp]:
  assumes v[simp]: "v : carrier_vec n" and w[simp]: "w : carrier_vec n"
  shows "conjugate (conjugate v  w) = v ∙c w"
  apply (subst conjugate_sprod_vec[of _ n]) by auto

(* Over a commutative conjugatable ring, and for v, w in carrier_vec n, the
   arguments of the conjugated scalar product can be swapped as stated.
   The sums are compared summand by summand (sum.ivl_cong with ac_simps). *)
lemma conjugate_vec_sprod_comm:
  fixes v w :: "'a :: {conjugatable_ring, comm_ring} vec"
  assumes "v : carrier_vec n" and "w : carrier_vec n"
  shows "v ∙c w = (conjugate w  v)"
  unfolding scalar_prod_def using assms by(subst sum.ivl_cong, auto simp: ac_simps)

(* Ordering fact comparing v ∙c v with 0 over a conjugatable_ordered_ring,
   registered as an intro! rule. Proved by structural induction on v; the
   vCons step uses conjugate_square_positive for the head entry. *)
lemma conjugate_square_ge_0_vec[intro!]:
  fixes v :: "'a :: conjugatable_ordered_ring vec"
  shows "v ∙c v  0"
proof (induct v)
  case vNil
  then show ?case by auto
next
  case (vCons a v)
  then show ?case using conjugate_square_positive[of a] by auto
qed

(* For v in carrier_vec n: relates v ∙c v = 0 to v being the zero vector of
   dimension n (simp rule). Proved by carrier_vec_induct; the step case
   combines conjugate_square_positive for the head with
   conjugate_square_ge_0_vec for the tail via add_nonneg_eq_0_iff. *)
lemma conjugate_square_eq_0_vec[simp]:
  fixes v :: "'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors} vec"
  assumes "v  carrier_vec n"
  shows "v ∙c v = 0  v = 0v n"
proof (insert assms, induct rule: carrier_vec_induct)
  case 0
  then show ?case by auto
next
  case (Suc n a v)
  then show ?case
    using conjugate_square_positive[of a] conjugate_square_ge_0_vec[of v]
    by (auto simp: le_less add_nonneg_eq_0_iff zero_vec_Suc)
qed

(* For v in carrier_vec n: relates v ∙c v > 0 to a comparison of v with the
   zero vector of dimension n (simp rule); proved with less_le. *)
lemma conjugate_square_greater_0_vec[simp]:
  fixes v :: "'a :: {conjugatable_ordered_ring,semiring_no_zero_divisors} vec"
  assumes "v  carrier_vec n"
  shows "v ∙c v > 0  v  0v n"
  using assms by (auto simp: less_le)

(* On rational vectors, conjugate is the identity function (simp rule). *)
lemma vec_conjugate_rat[simp]: "(conjugate :: rat vec  rat vec) = (λx. x)" by force
(* On real vectors, conjugate is the identity function (simp rule). *)
lemma vec_conjugate_real[simp]: "(conjugate :: real vec  real vec) = (λx. x)" by force


end

Theory Matrix_IArray_Impl

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Code Generation for Basic Matrix Operations›

text ‹In this theory we implement matrices as arrays of arrays.
  Due to the target language serialization, access to matrix
  entries should be constant time. Hence operations like
  matrix addition, multiplication, etc.~should all have their 
  standard complexity. 

  There might be room for optimizations. 

  To implement the infinite carrier set, we use A.\ Lochbihler's container framework
  \cite{Containers-AFP}.›

theory Matrix_IArray_Impl
imports
  Matrix
  "HOL-Library.IArray"
  Containers.Set_Impl
begin

(* Implementation type for vectors: pairs (n, v) of a number n and an
   immutable array v whose IArray.length is n. *)
typedef 'a vec_impl = "{(n,v :: 'a iarray). IArray.length v = n}" by auto
(* Implementation type for matrices: triples (nr, nc, m) where m is an array
   of row arrays, the outer length is nr and every row has length nc.
   Non-emptiness is witnessed by (0, 0, IArray []). *)
typedef 'a mat_impl = "{(nr,nc,m :: 'a iarray iarray). 
  IArray.length m = nr  IArray.all (λ r. IArray.length r = nc) m}" 
  by (rule exI[of _ "(0,0,IArray [])"], auto)

(* Register both implementation types with the lifting package so that the
   lift_definition commands below can be used. *)
setup_lifting type_definition_vec_impl
setup_lifting type_definition_mat_impl

(* Abstraction function from implementation vectors to vec: the pair (n, v)
   becomes (n, mk_vec n (IArray.sub v)). Declared as the code constructor
   for vec via code_datatype further below. *)
lift_definition vec_impl :: "'a vec_impl  'a vec" is
  "λ (n,v). (n,mk_vec n (IArray.sub v))" by auto

(* Entrywise addition on implementation vectors. The result has the length n
   of the first argument; the stored length m of the second argument is not
   used. *)
lift_definition vec_add_impl :: "'a::plus vec_impl  'a vec_impl  'a vec_impl" is
  "λ (n,v) (m,w).
  (n, IArray.of_fun (λi. IArray.sub v i + IArray.sub w i) n)"
by auto

(* Abstraction function from implementation matrices to mat: entry (i,j) is
   read as IArray.sub (IArray.sub m i) j and wrapped with mk_mat. Declared
   as the code constructor for mat via code_datatype further below. *)
lift_definition mat_impl :: "'a mat_impl  'a mat" is
  "λ (nr,nc,m). (nr,nc,mk_mat nr nc (λ (i,j). IArray.sub (IArray.sub m i) j))" by auto

(* Build an implementation vector from a list: length plus IArray of the
   list. *)
lift_definition vec_of_list_impl :: "'a list  'a vec_impl" is
  "λ v. (length v, IArray v)" by auto

(* Extract the entries of an implementation vector as a list. *)
lift_definition list_of_vec_impl :: "'a vec_impl  'a list" is
  "λ (n,v). IArray.list_of v" .
  
(* Build an implementation vector of length n by tabulating f with
   IArray.of_fun. *)
lift_definition vec_of_fun :: "nat  (nat  'a)  'a vec_impl" is
  "λ n f. (n, IArray.of_fun f n)" by auto

(* Build an nr x nc implementation matrix by tabulating f: an outer array of
   nr rows, each row an array of nc entries f (i,j). *)
lift_definition mat_of_fun :: "nat  nat  (nat × nat  'a)  'a mat_impl" is
  "λ nr nc f. (nr, nc, IArray.of_fun (λ i. IArray.of_fun (λ j. f (i,j)) nc) nr)" by auto

(* Index access on implementation vectors: plain IArray.sub on the array;
   the stored length is not consulted. *)
lift_definition vec_index_impl :: "'a vec_impl  nat  'a"
  is "λ (n,v). IArray.sub v" .

(* Index access on implementation matrices. For a row index i < nr the entry
   is read from row i at column j. For other row indices the access is
   expressed through [] ! (i - nr); index_mat_code below shows that this
   agrees with indexing the abstract matrix. *)
lift_definition index_mat_impl :: "'a mat_impl  nat × nat  'a"
  is "λ (nr,nc,m) (i,j). if i < nr then IArray.sub (IArray.sub m i) j 
    else IArray.sub (IArray ([] ! (i - nr))) j" .

(* Equality test on implementation vectors: compares the stored lengths and
   the arrays. *)
lift_definition vec_equal_impl :: "'a vec_impl  'a vec_impl  bool"
  is "λ (n1,v1) (n2,v2). n1 = n2  v1 = v2" .

(* Equality test on implementation matrices: compares both stored dimensions
   and the arrays. *)
lift_definition mat_equal_impl :: "'a mat_impl  'a mat_impl  bool"
  is "λ (nr1,nc1,m1) (nr2,nc2,m2). nr1 = nr2  nc1 = nc2  m1 = m2" .

(* Dimension of an implementation vector: first component of the pair. *)
lift_definition dim_vec_impl :: "'a vec_impl  nat" is fst .

(* Number of rows of an implementation matrix: first component of the triple. *)
lift_definition dim_row_impl :: "'a mat_impl  nat" is fst .
(* Number of columns of an implementation matrix: second component of the
   triple. *)
lift_definition dim_col_impl :: "'a mat_impl  nat" is "fst o snd" .

(* Use the abstraction functions vec_impl and mat_impl as the constructors of
   vec and mat in generated code. *)
code_datatype vec_impl
code_datatype mat_impl

(* Code equation: vec n f is implemented by tabulating f with vec_of_fun. *)
lemma vec_code[code]: "vec n f = vec_impl (vec_of_fun n f)"
  by (transfer, auto simp: mk_vec_def)

(* Code equation: mat nr nc f is implemented by tabulating f with
   mat_of_fun. *)
lemma mat_code[code]: "mat nr nc f = mat_impl (mat_of_fun nr nc f)"
  by (transfer, auto simp: mk_mat_def, intro ext, clarsimp, 
  auto intro: undef_cong_mat)

(* Code equation: vec_of_list is implemented by vec_of_list_impl. *)
lemma vec_of_list[code]: "vec_of_list v = vec_impl (vec_of_list_impl v)"
  by (transfer, auto simp: mk_vec_def)

(* Code equation: list_of_vec on an implemented vector is list_of_vec_impl. *)
lemma list_of_vec_code[code]: "list_of_vec (vec_impl v) = list_of_vec_impl v"
  by (transfer, auto simp: mk_vec_def, case_tac b, auto intro: nth_equalityI)

(* For an index i that is not below length x, x ! i coincides with indexing
   the empty list at i - length x. *)
lemma empty_nth: "¬ i < length x  x ! i = [] ! (i - length x)"
  by (metis append_Nil2 nth_append)

(* For an index i that is not below length x, undef_vec (i - length x)
   coincides with x ! i; this is empty_nth read right-to-left after unfolding
   undef_vec_def. *)
lemma undef_vec: "¬ i < length x  undef_vec (i - length x) = x ! i"
  unfolding undef_vec_def by (rule empty_nth[symmetric])
  
(* Code equation: indexing an implemented vector is vec_index_impl. No bound
   on i is assumed; lemma undef_vec is used in the proof. *)
lemma vec_index_code[code]: "(vec_impl v) $ i = vec_index_impl v i"
  by (transfer, auto simp: mk_vec_def, case_tac b, auto simp: undef_vec)

(* Code equation: indexing an implemented matrix is index_mat_impl. No bounds
   on the index pair are assumed, so the proof treats row indices outside and
   inside the array separately, and in the latter case also column indices
   outside and inside the row. *)
lemma index_mat_code[code]: "(mat_impl m) $$ ij = (index_mat_impl m ij :: 'a)"
proof (transfer, unfold o_def, clarify)
  fix m :: "'a iarray iarray" and i j nc
  assume all: "IArray.all (λr. IArray.length r = nc) m"
  obtain mm where m: "m = IArray mm" by (cases m)
  with all have all: " v. v  set mm  IArray.length v = nc" by auto
  show "snd (snd (IArray.length m, nc, mk_mat (IArray.length m) nc (λ(i, y). m !! i !! y))) (i, j) =
     (if i < IArray.length m then m !! i !! j
        else IArray ([] ! (i - IArray.length m)) !! j)" (is "?l = ?r")
  proof (cases "i < length mm")
    (* Row index outside the array: both sides reduce to indexing the empty
       list. *)
    case False
    hence " f. ¬ i < length (map f [0..<length mm])" by simp
    note [simp] = empty_nth[OF this]
    have "?l = [] ! (i - length mm) ! j" using False unfolding m mk_mat_def undef_mat_def by simp
    also have " = ?r" unfolding m by (simp add: False empty_nth[OF False])
    finally show ?thesis .
  next
    (* Row index inside the array: row i is some IArray v of length nc. *)
    case True
    obtain v where mm: "mm ! i = IArray v" by (cases "mm ! i")
    with True all[of "mm ! i"] have len: "length v = nc" unfolding set_conv_nth by force
    from mm True have "?l = map ((!) v) [0..<nc] ! j" (is "_ = ?m") unfolding m mk_mat_def undef_mat_def by simp
    also have "?m = m !! i !! j"
    proof (cases "j < length v")
      case True
      thus ?thesis unfolding m using mm len by auto
    next
      case False
      hence j: "¬ j < length (map ((!) v) [0..<length v])" by simp
      show ?thesis unfolding m using mm len by (auto simp: empty_nth[OF j] empty_nth[OF False])
    qed
    also have " = ?r" using True m by simp
    finally show ?thesis .
  qed
qed

(* Try to build an implementation matrix with n columns from a list of rows:
   returns Some of the matrix when every row has length n, and None
   otherwise. *)
lift_definition (code_dt) mat_of_rows_list_impl :: "nat  'a list list  'a mat_impl option" is
  "λ n rows. if list_all (λ r. length r = n) rows then Some (length rows, n, IArray (map IArray rows)) 
  else None" 
  by (auto split: if_splits simp: list_all_iff)

(* Correctness of mat_of_rows_list_impl: whenever it returns Some A, the
   abstraction of A is mat_of_rows_list n rs. *)
lemma mat_of_rows_list_impl: "mat_of_rows_list_impl n rs = Some A  mat_impl A = mat_of_rows_list n rs" 
  unfolding mat_of_rows_list_def
  by (transfer, auto split: if_splits simp: list_all_iff intro!: cong_mk_mat)
  
(* Code equation for mat_of_rows_list: use the array-based result of
   mat_of_rows_list_impl when it is Some A; in the None case fall back to
   mat_of_rows applied to the rows converted with vec nc (nth v). *)
lemma mat_of_rows_list_code[code]: "mat_of_rows_list nc vs = 
  (case mat_of_rows_list_impl nc vs of Some A  mat_impl A 
  | None  mat_of_rows nc (map (λ v. vec nc (nth v)) vs))"
proof (cases "mat_of_rows_list_impl nc vs")
  case (Some A)
  from mat_of_rows_list_impl[OF this] show ?thesis unfolding Some by simp
next
  case None
  show ?thesis unfolding None unfolding mat_of_rows_list_def mat_of_rows_def
    by (intro eq_matI, auto)  
qed

(* Code equation: dim_vec on an implemented vector is dim_vec_impl. *)
lemma dim_vec_code[code]: "dim_vec (vec_impl v) = dim_vec_impl v"
  by (transfer, auto)

(* Code equation: dim_row on an implemented matrix is dim_row_impl. *)
lemma dim_row_code[code]: "dim_row (mat_impl m) = dim_row_impl m"
  by (transfer, auto)

(* Code equation: dim_col on an implemented matrix is dim_col_impl. *)
lemma dim_col_code[code]: "dim_col (mat_impl m) = dim_col_impl m"
  by (transfer, auto)

(* Instance of class equal for vectors: equal_vec is defined as (=).
   The executable version is supplied by veq_equal_code below. *)
instantiation vec :: (type)equal
begin
  definition "(equal_vec :: ('a vec  'a vec  bool)) = (=)"
instance
  by (intro_classes, auto simp: equal_vec_def)
end

(* Instance of class equal for matrices: equal_mat is defined as (=).
   The executable version is supplied by mat_equal_code below. *)
instantiation mat :: (type)equal
begin
  definition "(equal_mat :: ('a mat  'a mat  bool)) = (=)"
instance
  by (intro_classes, auto simp: equal_mat_def)
end

(* Code equation: HOL.equal on two implemented vectors is computed by
   vec_equal_impl. The auxiliary fact * states that two lists of equal length
   are equal if the index functions built from them (with the undef_vec
   fallback past the end) coincide. *)
lemma veq_equal_code[code]: "HOL.equal (vec_impl (v1 :: 'a vec_impl)) (vec_impl v2) = vec_equal_impl v1 v2"
proof - 
  {
    fix x1 x2 :: "'a list"
    assume len: "length x1 = length x2"
       and index: "(λi. if i < length x2 then IArray x1 !! i else undef_vec (i - length (IArray.list_of (IArray x1)))) =
            (λi. if i < length x2 then IArray x2 !! i else undef_vec (i - length (IArray.list_of (IArray x2))))"    
    have "x1 = x2"
    proof (intro nth_equalityI[OF len])
      fix i
      assume "i < length x1"
      with fun_cong[OF index, of i] len show "x1 ! i = x2 ! i" by simp
    qed
  } note * = this
  show ?thesis unfolding equal_vec_def
    by (transfer, insert *, auto simp: mk_vec_def, case_tac b, case_tac ba, auto)
qed

(* Code equation: HOL.equal on two implemented matrices is computed by
   mat_equal_impl. The remaining goal after transfer is that two lists of
   row arrays with equal length, all rows of length nc, and equal mk_mat
   functions are equal; it is shown row by row and then entry by entry. *)
lemma mat_equal_code[code]: "HOL.equal (mat_impl (m1 :: 'a mat_impl)) (mat_impl m2) = mat_equal_impl m1 m2"
proof - 
  show ?thesis unfolding equal_mat_def
  proof (transfer, auto, case_tac b, case_tac ba, auto)
    fix x1 x2 :: "'a iarray list" and nc
    assume len: "rset x1. length (IArray.list_of r) = nc"
      "rset x2. length (IArray.list_of r) = nc"
      "length x1 = length x2"
    and index: "mk_mat (length x2) nc (λ(i, j). x1 ! i !! j) = mk_mat (length x2) nc (λ(i, j). x2 ! i !! j)"
    show "x1 = x2"
    proof (rule nth_equalityI[OF len(3)])
      fix i
      assume i: "i < length x1"
      obtain ia1 where 1: "x1 ! i = IArray ia1" by (cases "x1 ! i")
      obtain ia2 where 2: "x2 ! i = IArray ia2" by (cases "x2 ! i")
      from i 1 len(1) have l1: "length ia1 = nc" using nth_mem by fastforce
      from i 2 len(2-3) have l2: "length ia2 = nc" using nth_mem by fastforce
      from l1 l2 have l: "length ia1 = length ia2" by simp
      show "x1 ! i = x2 ! i" unfolding 1 2
      proof (simp, rule nth_equalityI[OF l])
        fix j
        assume j: "j < length ia1"
        with fun_cong[OF index, of "(i,j)"] i len(3)
        have "x1 ! i !! j = x2 ! i !! j"
          by (simp add: mk_mat_def l1)
        thus "ia1 ! j = ia2 ! j" unfolding 1 2 by simp
      qed
    qed
  qed
qed  

(* Remove prod.set_conv_list as a code equation and register it as a
   code_unfold rule instead. *)
declare prod.set_conv_list[code del, code_unfold]

(* Instances for the container framework for the types mat and vec:
   ceq via equality, no ccompare, dlist as set implementation, no cenum. *)
derive (eq) ceq mat vec
derive (no) ccompare mat vec
derive (dlist) set_impl mat vec
derive (no) cenum mat vec

(* Code equation: carrier_mat nr nc is represented as a Collect_set over the
   predicate that tests both dimensions. *)
lemma carrier_mat_code[code]: "carrier_mat nr nc = Collect_set (λ A. dim_row A = nr  dim_col A = nc)" by auto
(* Code equation: carrier_vec n is represented as a Collect_set over the
   predicate that tests the dimension. *)
lemma carrier_vec_code[code]: "carrier_vec n = Collect_set (λ v. dim_vec v = n)" 
  unfolding carrier_vec_def by auto

end

Theory Gauss_Jordan_Elimination

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Gauss-Jordan Algorithm›

text ‹We define the elementary row operations and use them to implement the
  Gauss-Jordan algorithm to transform matrices into row-echelon-form. 
  This algorithm is used to implement the inverse of a matrix and to derive
  certain results on determinants, as well as to determine a basis of the kernel
  of a matrix.› 

theory Gauss_Jordan_Elimination
imports Matrix
begin

subsection ‹Row Operations›

(* Row scaling with an explicit multiplication: every entry of row k of A is
   replaced by mul a applied to it; all other rows are unchanged. *)
definition mat_multrow_gen :: "('a  'a  'a)  nat  'a  'a mat  'a mat" where
  "mat_multrow_gen mul k a A = mat (dim_row A) (dim_col A) 
     (λ (i,j). if k = i then mul a (A $$ (i,j)) else A $$ (i,j))"

(* Row scaling over a semiring_1: mat_multrow_gen instantiated with the
   ring multiplication. *)
abbreviation mat_multrow :: "nat  'a :: semiring_1  'a mat  'a mat" ("multrow") where
  "multrow  mat_multrow_gen ((*))"

(* Alias so that the generic definition can be referred to under the name of
   the abbreviation. *)
lemmas mat_multrow_def = mat_multrow_gen_def

(* Elementary n x n matrix for row scaling: a at position (k,k), 1 on the
   rest of the diagonal and 0 elsewhere. Lemma multrow_mat shows that
   multiplying with it from the left is multrow k a. *)
definition multrow_mat :: "nat  nat  'a :: semiring_1  'a mat" where
  "multrow_mat n k a = mat n n 
     (λ (i,j). if k = i  k = j then a else if i = j then 1 else 0)"

(* Row swap: row k of the result is row l of A, row l of the result is row k
   of A, all other rows are unchanged. *)
definition mat_swaprows :: "nat  nat  'a mat  'a mat" ("swaprows")where
  "swaprows k l A = mat (dim_row A) (dim_col A) 
    (λ (i,j). if k = i then A $$ (l,j) else if l = i then A $$ (k,j) else A $$ (i,j))"

(* Elementary n x n matrix for swapping rows k and l, with entries 1 and 0
   only. Lemma swaprows_mat shows that multiplying with it from the left is
   swaprows k l. *)
definition swaprows_mat :: "nat  nat  nat  'a :: semiring_1 mat" where
  "swaprows_mat n k l = mat n n
    (λ (i,j). if k = i  l = j  k = j  l = i  i = j  i  k  i  l then 1 else 0)"

(* Row addition with explicit operations: each entry of row k becomes
   ad (mul a (entry of row l in the same column)) (old entry); all other
   rows are unchanged. *)
definition mat_addrow_gen :: "('a  'a  'a)  ('a  'a  'a)  'a  nat  nat  'a mat  'a mat" where
  "mat_addrow_gen ad mul a k l A = mat (dim_row A) (dim_col A) 
    (λ (i,j). if k = i then ad (mul a (A $$ (l,j))) (A $$ (i,j)) else A $$ (i,j))"

(* Row addition over a semiring_1: mat_addrow_gen instantiated with the ring
   addition and multiplication, i.e. a times row l is added to row k. *)
abbreviation mat_addrow :: "'a :: semiring_1  nat  nat  'a mat  'a mat" ("addrow") where
  "addrow  mat_addrow_gen (+) ((*))"

(* Alias so that the generic definition can be referred to under the name of
   the abbreviation. *)
lemmas mat_addrow_def = mat_addrow_gen_def

(* Elementary n x n matrix for row addition: the identity-matrix entry
   (1 on the diagonal, 0 elsewhere) with a added at position (k,l).
   Lemma addrow_mat shows that multiplying with it from the left is
   addrow a k l. *)
definition addrow_mat :: "nat  'a :: semiring_1  nat  nat  'a mat" where
  "addrow_mat n a k l = mat n n (λ (i,j). 
    (if k = i  l = j then (+) a else id) (if i = j then 1 else 0))"

(* Simp rules for mat_multrow_gen: the in-range entry in general form, the
   two specialised forms for the scaled row and for any other row, and the
   unchanged dimensions. *)
lemma index_mat_multrow[simp]: 
  "i < dim_row A  j < dim_col A  mat_multrow_gen mul k a A $$ (i,j) = (if k = i then mul a (A $$ (i,j)) else A $$ (i,j))"
  "i < dim_row A  j < dim_col A  mat_multrow_gen mul i a A $$ (i,j) = mul a (A $$ (i,j))"
  "i < dim_row A  j < dim_col A  k  i  mat_multrow_gen mul k a A $$ (i,j) = A $$ (i,j)"
  "dim_row (mat_multrow_gen mul k a A) = dim_row A" "dim_col (mat_multrow_gen mul k a A) = dim_col A"
  unfolding mat_multrow_def by auto

(* Simp rules for multrow_mat: in-range entries and both dimensions. *)
lemma index_mat_multrow_mat[simp]:
  "i < n  j < n  multrow_mat n k a $$ (i,j) = (if k = i  k = j then a else if i = j 
     then 1 else 0)"
  "dim_row (multrow_mat n k a) = n" "dim_col (multrow_mat n k a) = n"
  unfolding multrow_mat_def by auto

(* Simp rules for swaprows: in-range entries and the unchanged dimensions. *)
lemma index_mat_swaprows[simp]: 
  "i < dim_row A  j < dim_col A  swaprows k l A $$ (i,j) = (if k = i then A $$ (l,j) else 
    if l = i then A $$ (k,j) else A $$ (i,j))"
  "dim_row (swaprows k l A) = dim_row A" "dim_col (swaprows k l A) = dim_col A"
  unfolding mat_swaprows_def by auto

(* Simp rules for swaprows_mat: in-range entries and both dimensions. *)
lemma index_mat_swaprows_mat[simp]:
  "i < n  j < n  swaprows_mat n k l $$ (i,j) = 
    (if k = i  l = j  k = j  l = i  i = j  i  k  i  l then 1 else 0)"
  "dim_row (swaprows_mat n k l) = n" "dim_col (swaprows_mat n k l) = n"
  unfolding swaprows_mat_def by auto

(* Simp rules for mat_addrow_gen: the in-range entry in general form, the
   two specialised forms for the modified row and for any other row, and the
   unchanged dimensions. *)
lemma index_mat_addrow[simp]: 
  "i < dim_row A  j < dim_col A  mat_addrow_gen ad mul a k l A $$ (i,j) = (if k = i then 
    ad (mul a (A $$ (l,j))) (A $$ (i,j)) else A $$ (i,j))"
  "i < dim_row A  j < dim_col A  mat_addrow_gen ad mul a i l A $$ (i,j) = ad (mul a (A $$ (l,j))) (A $$ (i,j))"
  "i < dim_row A  j < dim_col A  k  i  mat_addrow_gen ad mul a k l A $$ (i,j) = A $$(i,j)"
  "dim_row (mat_addrow_gen ad mul a k l A) = dim_row A" "dim_col (mat_addrow_gen ad mul a k l A) = dim_col A"
  unfolding mat_addrow_def by auto

(* Simp rules for addrow_mat: in-range entries and both dimensions. *)
lemma index_mat_addrow_mat[simp]:
  "i < n  j < n  addrow_mat n a k l $$ (i,j) = 
    (if k = i  l = j then (+) a else id) (if i = j then 1 else 0)"
  "dim_row (addrow_mat n a k l) = n" "dim_col (addrow_mat n a k l) = n"
  unfolding addrow_mat_def by auto

(* Row scaling does not change carrier membership (simp rule). *)
lemma multrow_carrier[simp]: "(mat_multrow_gen mul k a A  carrier_mat n nc) = (A  carrier_mat n nc)"
  unfolding carrier_mat_def by fastforce

(* multrow_mat n k a is an n x n matrix (simp rule). *)
lemma multrow_mat_carrier[simp]: "multrow_mat n k a  carrier_mat n n"
  unfolding carrier_mat_def by auto

(* addrow_mat n a k l is an n x n matrix (simp rule). *)
lemma addrow_mat_carrier[simp]: "addrow_mat n a k l  carrier_mat n n"
  unfolding carrier_mat_def by auto

(* swaprows_mat n k l is an n x n matrix (simp rule). *)
lemma swaprows_mat_carrier[simp]: "swaprows_mat n k l  carrier_mat n n"
  unfolding carrier_mat_def by auto

(* Swapping rows does not change carrier membership (simp rule). *)
lemma swaprows_carrier[simp]: "(swaprows k l A  carrier_mat n nc) = (A  carrier_mat n nc)"
  unfolding carrier_mat_def by fastforce

(* Row addition does not change carrier membership (simp rule). *)
lemma addrow_carrier[simp]: "(mat_addrow_gen ad mul a k l A  carrier_mat n nc) = (A  carrier_mat n nc)"
  unfolding carrier_mat_def by fastforce

(* Rows of multrow_mat n k a: an in-range row i other than k is the unit
   vector unit_vec n i, and row k is unit_vec n k scaled by a. *)
lemma row_multrow:  "k  i  i < n  row (multrow_mat n k a) i = unit_vec n i"
  "k < n  row (multrow_mat n k a) k = a v unit_vec n k"
  by (rule eq_vecI, auto)

(* Row scaling as a matrix product: for A with n rows, multrow k a A equals
   multrow_mat n k a multiplied from the left. *)
lemma multrow_mat: assumes A: "A  carrier_mat n nc"
  shows "multrow k a A = multrow_mat n k a * A"
  by (rule eq_matI, insert A, auto simp: row_multrow smult_scalar_prod_distrib[of _ n])

(* Rows of addrow_mat n a k l: an in-range row i other than k is
   unit_vec n i, and for k, l < n row k is a times unit_vec n l plus
   unit_vec n k. *)
lemma row_addrow: 
  "k  i  i < n  row (addrow_mat n a k l) i = unit_vec n i"
  "k < n  l < n  row (addrow_mat n a k l) k = a v unit_vec n l + unit_vec n k"
  by (rule eq_vecI, auto)

(* Row addition as a matrix product: for A with n rows and l < n,
   addrow a k l A equals addrow_mat n a k l multiplied from the left. *)
lemma addrow_mat: assumes A: "A  carrier_mat n nc" 
  and l: "l < n"
  shows "addrow a k l A = addrow_mat n a k l * A"
  by (rule eq_matI, insert l A, auto simp: row_addrow 
  add_scalar_prod_distrib[of _ n] smult_scalar_prod_distrib[of _ n])

(* Rows of swaprows_mat as unit vectors: the degenerate case k = l, rows
   other than k and l, and rows l and k, which are unit_vec n k and
   unit_vec n l respectively. *)
lemma row_swaprows: 
  "l < n  row (swaprows_mat n l l) l = unit_vec n l"
  "i  k  i  l  i < n  row (swaprows_mat n k l) i = unit_vec n i"
  "k < n  l < n  row (swaprows_mat n k l) l = unit_vec n k"
  "k < n  l < n  row (swaprows_mat n k l) k = unit_vec n l"
  by (rule eq_vecI, auto)

(* Row swap as a matrix product: for A with n rows and k, l < n,
   swaprows k l A equals swaprows_mat n k l multiplied from the left. *)
lemma swaprows_mat: assumes A: "A  carrier_mat n nc" and k: "k < n" "l < n"
  shows "swaprows k l A = swaprows_mat n k l * A"
  by (rule eq_matI, insert A k, auto simp: row_swaprows)

(* For k, l < n the matrix swaprows_mat n k l is its own inverse. The proof
   inserts the identity matrix on the right, turns both products into
   applications of swaprows via lemma swaprows_mat, and compares the result
   with the identity matrix entrywise. *)
lemma swaprows_mat_inv: assumes k: "k < n" and l: "l < n"
  shows "swaprows_mat n k l * swaprows_mat n k l = 1m n"
proof -
  have "swaprows_mat n k l * swaprows_mat n k l = 
    swaprows_mat n k l * (swaprows_mat n k l * 1m n)"
    by (simp add: right_mult_one_mat[of _ n])
  also have "swaprows_mat n k l * 1m n = swaprows k l (1m n)"
    by (rule swaprows_mat[symmetric, OF _ k l, of _ n], simp)
  also have "swaprows_mat n k l *  = swaprows k l "
    by (rule swaprows_mat[symmetric, of _ _ n], insert k l, auto)
  also have " = 1m n"
    by (rule eq_matI, insert k l, auto)
  finally show ?thesis .
qed

(* For k, l < n, swaprows_mat n k l is a unit of ring_mat TYPE('a) n b; the
   inverse supplied to bexI is the matrix itself (see swaprows_mat_inv). *)
lemma swaprows_mat_Unit: assumes k: "k < n" and l: "l < n"
  shows "swaprows_mat n k l  Units (ring_mat TYPE('a :: semiring_1) n b)"
proof -
  interpret m: semiring "ring_mat TYPE('a) n b" by (rule semiring_mat)
  show ?thesis unfolding Units_def
    by (rule, rule conjI[OF _ bexI[of _ "swaprows_mat n k l"]],
    auto simp: ring_mat_def swaprows_mat_inv[OF k l] swaprows_mat_inv[OF l k])
qed

(* Over a comm_ring_1, with k, l < n and the side condition neq on k and l:
   addrow_mat n a k l times addrow_mat n (- a) k l is the identity matrix.
   Same proof scheme as swaprows_mat_inv, via lemma addrow_mat. *)
lemma addrow_mat_inv: assumes k: "k < n" and l: "l < n" and neq: "k  l"
  shows "addrow_mat n a k l * addrow_mat n (- (a :: 'a :: comm_ring_1)) k l = 1m n"
proof -
  have "addrow_mat n a k l * addrow_mat n (- a) k l = 
    addrow_mat n a k l * (addrow_mat n (- a) k l * 1m n)"
    by (simp add: right_mult_one_mat[of _ n])
  also have "addrow_mat n (- a) k l * 1m n = addrow (- a) k l (1m n)"
    by (rule addrow_mat[symmetric, of _ _ n], insert k l, auto)
  also have "addrow_mat n a k l *  = addrow a k l "
    by (rule addrow_mat[symmetric, of _ _ n], insert k l, auto)
  also have " = 1m n"
    by (rule eq_matI, insert k l neq, auto, algebra)
  finally show ?thesis .
qed

(* Under the assumptions of addrow_mat_inv, addrow_mat n a k l is a unit of
   ring_mat TYPE('a) n b; the inverse supplied to bexI is
   addrow_mat n (- a) k l. The other product is obtained from
   addrow_mat_inv instantiated with - a. *)
lemma addrow_mat_Unit: assumes k: "k < n" and l: "l < n" and neq: "k  l"
  shows "addrow_mat n a k l  Units (ring_mat TYPE('a :: comm_ring_1) n b)"
proof -
  interpret m: semiring "ring_mat TYPE('a) n b" by (rule semiring_mat)
  show ?thesis unfolding Units_def
    by (rule, rule conjI[OF _ bexI[of _ "addrow_mat n (- a) k l"]], insert neq,
    auto simp: ring_mat_def addrow_mat_inv[OF k l neq], 
    rule trans[OF _ addrow_mat_inv[OF k l neq, of "- a"]], auto)
qed

(* Over a division_ring, with k < n and the side condition a on the scalar:
   multrow_mat n k a times multrow_mat n k (inverse a) is the identity
   matrix. Same proof scheme as swaprows_mat_inv, via lemma multrow_mat. *)
lemma multrow_mat_inv: assumes k: "k < n" and a: "(a :: 'a :: division_ring)  0"
  shows "multrow_mat n k a * multrow_mat n k (inverse a) = 1m n"
proof -
  have "multrow_mat n k a * multrow_mat n k (inverse a) = 
    multrow_mat n k a * (multrow_mat n k (inverse a) * 1m n)"
    using k by (simp add: right_mult_one_mat[of _ n])
  also have "multrow_mat n k (inverse a) * 1m n = multrow k (inverse a) (1m n)"
    by (rule multrow_mat[symmetric, of _ _ n], insert k, auto)
  also have "multrow_mat n k a *  = multrow k a "
    by (rule multrow_mat[symmetric, of _ _ n], insert k, auto)
  also have " = 1m n"
    by (rule eq_matI, insert a k a, auto)
  finally show ?thesis .
qed

(* Under the assumptions of multrow_mat_inv, multrow_mat n k a is a unit of
   ring_mat TYPE('a) n b; the inverse supplied to bexI is
   multrow_mat n k (inverse a). Fact ia transfers the side condition from a
   to inverse a so that multrow_mat_inv can be used for the other product. *)
lemma multrow_mat_Unit: assumes k: "k < n" and a: "(a :: 'a :: division_ring)  0"
  shows "multrow_mat n k a  Units (ring_mat TYPE('a) n b)"
proof -
  from a have ia: "inverse a  0" by auto
  interpret m: semiring "ring_mat TYPE('a) n b" by (rule semiring_mat)
  show ?thesis unfolding Units_def
    by (rule, rule conjI[OF _ bexI[of _ "multrow_mat n k (inverse a)"]], insert a,
    auto simp: ring_mat_def multrow_mat_inv[OF k],
    rule trans[OF _ multrow_mat_inv[OF k ia]], insert a, auto)
qed

subsection ‹Gauss-Jordan Elimination›

(* Recursive elimination over a list of (factor, row index) pairs: for each
   pair (ai'j, i'), in list order, ai'j times row i is added to row i' of
   the current matrix. The empty list returns the matrix unchanged. *)
fun eliminate_entries_rec where
  "eliminate_entries_rec B i [] = B"
| "eliminate_entries_rec B i ((ai'j,i') # is) = ( 
  eliminate_entries_rec (mat_addrow_gen ((+) :: 'b :: ring_1  'b  'b) (*) ai'j i' i B) i is)"

(* Generic one-pass elimination, parameterised by the subtraction and
   multiplication operations. *)
context
  fixes minus :: "'a  'a  'a"
  and times :: "'a  'a  'a"
begin

(* For every row i that passes the test against I, entry (i,j) becomes
   minus (A $$ (i,j)) (times (v i) (A $$ (I,j))); other entries are kept.
   The argument J does not occur in the defining term. *)
definition eliminate_entries_gen :: "(nat  'a)  'a mat  nat  nat  'a mat" where
  "eliminate_entries_gen v A I J = mat (dim_row A) (dim_col A) (λ (i, j).
     if i  I then minus (A $$ (i,j)) (times (v i) (A $$ (I,j))) else A $$ (i,j))" 

(* eliminate_entries_gen keeps both dimensions. *)
lemma dim_eliminate_entries_gen[simp]: "dim_row (eliminate_entries_gen v B i as) = dim_row B"
  "dim_col (eliminate_entries_gen v B i as) = dim_col B"
  unfolding eliminate_entries_gen_def by auto
  
(* eliminate_entries_rec keeps the number of columns. *)
lemma dimc_eliminate_entries_rec[simp]: "dim_col (eliminate_entries_rec B i as) = dim_col B"
  by (induct as arbitrary: B, auto simp: Let_def)

(* eliminate_entries_rec keeps the number of rows. *)
lemma dimr_eliminate_entries_rec[simp]: "dim_row (eliminate_entries_rec B i as) = dim_row B"
  by (induct as arbitrary: B, auto simp: Let_def)

(* Both elimination functions stay inside carrier_mat nr nc. *)
lemma carrier_eliminate_entries: "A  carrier_mat nr nc  eliminate_entries_gen v A i bs  carrier_mat nr nc"
  "B  carrier_mat nr nc  eliminate_entries_rec B i as  carrier_mat nr nc"
  unfolding carrier_mat_def by auto
end

(* One-pass elimination over a ring_1: eliminate_entries_gen instantiated
   with the ring subtraction and multiplication. *)
abbreviation "eliminate_entries  eliminate_entries_gen (-) ((*) :: 'a :: ring_1  'a  'a)"

(* Connects the one-pass elimination with the recursive one: using column J
   of A as factors, eliminate_entries on B with pivot row I equals
   eliminate_entries_rec applied to the list of pairs (- A $$ (i, J), i) for
   the row indices of A that pass the filter against I.
   The auxiliary function one_go performs the elimination in one step for
   the rows contained in a given index list; the inner block shows by
   induction over a distinct index list that the recursive version agrees
   with one_go. *)
lemma eliminate_entries_convert: 
  assumes jA: "J < dim_col A" and *: "I < dim_row A" "dim_row B = dim_row A" 
  shows "eliminate_entries (λ i. A $$ (i,J)) B I J = 
    eliminate_entries_rec B I (map (λ i. (- A $$ (i, J), i)) (filter (λ i. i  I) [0 ..< dim_row A]))"
proof -
  let ?ais = "λ is. map (λ i. (- A $$ (i, J), i)) (filter (λ i. i  I) is)" 
  define one_go where "one_go = (λ B is. mat (dim_row B) (dim_col B) (λ (i, j).
    if i  I  i  set is then B $$ (i,j) - (A $$ (i,J))  * B $$ (I,j) else B $$ (i,j)))"
  {
    fix "is" :: "nat list" 
    assume "distinct is"     
    from * this have "eliminate_entries_rec B I (?ais is) = one_go B is" 
    proof (induct "is" arbitrary: B)
      case Nil
      show ?case unfolding one_go_def
        by (rule eq_matI, auto)
    next
      case (Cons i "is")
      note I = Cons(2) note dim = Cons(3)      
      note II = Cons(2)[folded dim]
      let ?B = "addrow (- A $$ (i, J)) i I B" 
      from Cons(4) I dim have "I < dim_row A" "dim_row ?B = dim_row A" and dist: "distinct is" by auto
      note IH = Cons(1)[OF this]
      from Cons(4) have i: "i  set is" by auto
      show ?case 
      proof (cases "i = I")
        (* Head index differs from the pivot row: one addrow step is
           performed and the induction hypothesis is used for ?B. *)
        case False
        hence id: "?ais (i # is) = (- A $$ (i, J), i) # ?ais is" by simp
        show ?thesis unfolding id eliminate_entries_rec.simps IH
          unfolding one_go_def index_mat_addrow
        proof (rule eq_matI, goal_cases)
          case (1 ii jj)
          hence ii: "ii < dim_row B" and jj: "jj < dim_col B" and iiA: "ii < dim_row A" using dim by auto
          show ?case unfolding index_mat[OF ii jj] split
            index_mat_addrow(1)[OF ii jj] index_mat_addrow(1)[OF II jj]
            using i False by auto 
        qed auto
      next
        (* Head index is the pivot row: the filter drops it. *)
        case True
        hence id: "?ais (i # is) = ?ais is" by simp        
        show ?thesis unfolding id Cons(1)[OF I dim dist]
          unfolding one_go_def True by auto
      qed
    qed
  } note main = this
  show ?thesis
    by (subst main, force, unfold one_go_def eliminate_entries_gen_def, rule eq_matI, 
    insert *, auto)
qed

(* The recursive elimination is a left multiplication by a unit: for i < nr
   and a list is whose row indices satisfy the stated side conditions, there
   is a unit P of ring_mat TYPE('a) nr b such that
   eliminate_entries_rec B i is = P * B for every B in carrier_mat nr nc.
   Induction on the list: the empty list uses the identity matrix; in the
   step the unit obtained for the tail is multiplied with the elementary
   matrix addrow_mat nr a i' i of the head. *)
lemma Unit_prod_eliminate_entries: "i < nr  ( a i'. (a, i')  set is  i' < nr  i'  i)
    P  Units (ring_mat TYPE('a :: comm_ring_1) nr b) .  B nc. B  carrier_mat nr nc  eliminate_entries_rec B i is = P * B" 
proof (induct "is")
  case Nil
  thus ?case by (intro bexI[of _ "1m nr"], auto simp: Units_def ring_mat_def)
next
  case (Cons ai' "is")
  obtain a i' where ai': "ai' = (a,i')" by force
  let ?U = "Units (ring_mat TYPE('a) nr b)"
  interpret m: ring "ring_mat TYPE('a) nr b" by (rule ring_mat)
  from Cons(1)[OF Cons(2-3)] 
  obtain P where P: "P  ?U" and id: " B nc . B  carrier_mat nr nc  
    eliminate_entries_rec B i is = P * B" by force
  let ?Add = "addrow_mat nr a i' i"
  have Add: "?Add  ?U"
    by (rule addrow_mat_Unit, insert Cons ai', auto)
  from m.Units_m_closed[OF P Add] have PI: "P * ?Add  ?U" unfolding ring_mat_def by simp
  from m.Units_closed[OF P] have P: "P  carrier_mat nr nr" unfolding ring_mat_def by simp
  show ?case
  proof (rule bexI[OF _ PI], intro allI impI)
    fix B :: "'a mat" and nc
    assume BB: "B  carrier_mat nr nc"
    let ?B = "addrow a i' i B"
    from BB have B: "?B  carrier_mat nr nc" by simp
    from id[OF B] have id: "eliminate_entries_rec ?B i is = P * ?B" .
    have id2: "eliminate_entries_rec B i (ai' # is) = eliminate_entries_rec ?B i is" unfolding ai' by simp
    show "eliminate_entries_rec B i (ai' # is) = P * ?Add * B"
      unfolding id2 id unfolding addrow_mat[OF BB Cons(2)]
      by (rule assoc_mult_mat[symmetric, OF P _ BB], auto)
  qed
qed

(* gauss_jordan_main A B i j: Gauss-Jordan elimination on A, starting at row i
   and column j, applying every row operation to B as well.
   With nr = dim_row A and nc = dim_col A:
   - if (i,j) is no longer inside the matrix, return (A,B);
   - if A $$ (i,j) = 0, collect the rows i' in [Suc i ..< nr] selected by the
     filter on A $$ (i',j); with no such row continue with column Suc j,
     otherwise swap row i with the first such row and retry at (i,j);
   - if A $$ (i,j) = 1, eliminate the entries of column j using row i and
     continue at (Suc i, Suc j);
   - otherwise multiply row i by the inverse of A $$ (i,j) and retry at (i,j). *)
function gauss_jordan_main :: "'a :: field mat  'a mat  nat  nat  'a mat × 'a mat" where
  "gauss_jordan_main A B i j = (let nr = dim_row A; nc = dim_col A in
    if i < nr  j < nc then let aij = A $$ (i,j) in if aij = 0 then
      (case [ i' . i' <- [Suc i ..< nr],  A $$ (i',j)  0] 
        of []  gauss_jordan_main A B i (Suc j)
         | (i' # _)  gauss_jordan_main (swaprows i i' A) (swaprows i i' B) i j)
      else if aij = 1 then let 
        v = (λ i. A $$ (i,j)) in
        gauss_jordan_main 
        (eliminate_entries v A i j) (eliminate_entries v B i j) (Suc i) (Suc j)
      else let iaij = inverse aij in gauss_jordan_main (multrow i iaij A) (multrow i iaij B) i j
    else (A,B))"
  by pat_completeness auto

(* Termination uses a list of two measures:
   1. dim_col A - j, the number of remaining columns;
   2. a tag for the entry at (i,j): 2 if it is 0, 1 if it is neither 0 nor 1,
      and 0 if it is 1.
   Only the swap case is proved explicitly; the remaining recursive calls are
   discharged by auto. *)
termination
proof -
  let ?R = "measures [λ (A :: 'a :: field mat,B,i,j). dim_col A - j, 
    λ (A,B,i,j). if A $$ (i,j) = 0 then 2 else if A $$ (i,j) = 1 then 0 else 1]"
  show ?thesis
  proof
    show "wf ?R" by auto
  next
    (* swap case: i' is the head of the filtered row list *)
    fix A B :: "'a mat" and i j nr nc a i' "is"
    assume *: "nr = dim_row A" "nc = dim_col A" "i < nr  j < nc" "a = A $$ (i, j)" "a = 0"
      and ne: "[ i' . i' <- [Suc i ..< nr],  A $$ (i',j)  0] = i' # is"
    from ne have "i'  set ([ i' . i' <- [Suc i ..< nr],  A $$ (i',j)  0])" by auto
    with *
    show "((swaprows i i' A, swaprows i i' B, i, j), A, B, i, j)  ?R" by auto
  qed auto
qed

(* The defining equation is recursive on every branch, so it is removed from
   the simpset and unfolded explicitly where needed. *)
declare gauss_jordan_main.simps[simp del]

(* gauss_jordan A B: entry point, runs gauss_jordan_main from row 0 and column 0. *)
definition "gauss_jordan A B  gauss_jordan_main A B 0 0"

(* gauss_jordan_transform: if gauss_jordan A B = (A',B') for A and B with nr
   rows, then A' and B' arise from A and B by left-multiplication with one
   common unit P of the nr x nr matrix ring.
   Proof: the statement is shown for gauss_jordan_main at arbitrary (i,j) by
   computation induction, composing the unit from the induction hypothesis
   with the unit of the row operation done in the current step. *)
lemma gauss_jordan_transform: assumes A: "A  carrier_mat nr nc" and B: "B  carrier_mat nr nc'"
  and res: "gauss_jordan (A :: 'a :: field mat) B = (A',B')"
  shows " P  Units (ring_mat TYPE('a) nr b). A' = P * A  B' = P * B"
proof -
  let ?U = "Units (ring_mat TYPE('a) nr b)"
  interpret m: ring "ring_mat TYPE('a) nr b" by (rule ring_mat)
  {
    fix i j :: nat
    assume "gauss_jordan_main A B i j = (A',B')"
    with A B
    have " P  ?U. A' = P * A  B' = P * B"
    proof (induction A B i j rule: gauss_jordan_main.induct)
      case (1 A B i j)
      (* 1(1-4): induction hypotheses, one per recursive call;
         1(5), 1(6): carrier facts for A and B; 1(7): result of this call *)
      note A = 1(5)
      hence dim: "dim_row A = nr" "dim_col A = nc" by auto
      note B = 1(6)
      hence dimB: "dim_row B = nr" "dim_col B = nc'" by auto
      note IH = 1(1-4)[OF dim[symmetric]]
      note res = 1(7)
      note simp = gauss_jordan_main.simps[of A B i j] Let_def
      let ?g = "gauss_jordan_main A B i j"
      show ?case 
      proof (cases "i < nr  j < nc")
        case False
        (* outside the matrix: nothing changes, the identity matrix is the unit *)
        with res have res: "A' = A" "B' = B" unfolding simp dim by auto
        show ?thesis unfolding res
          by (rule bexI[of _ "1m nr"], insert A B, auto simp: Units_def ring_mat_def)
      next
        case True note valid = this
        note IH = IH[OF valid refl]
        show ?thesis 
        proof (cases "A $$ (i,j) = 0")
          case False note nZ = this
          note IH = IH(3-4)[OF nZ]
          show ?thesis
          proof (cases "A $$ (i,j) = 1")
            case False note nO = this
            (* entry is neither 0 nor 1: row i is scaled by the inverse entry *)
            let ?inv = "inverse (A $$ (i,j))"
            from nO nZ valid res 
            have "gauss_jordan_main (multrow i ?inv A) (multrow i ?inv B) i j = (A',B')"
              unfolding simp dim by simp
            note IH = IH(2)[OF nO refl, unfolded multrow_carrier, OF A B this]
            from IH obtain P where P: "P  ?U" and
              id: "A' = P * multrow i ?inv A" "B' = P * multrow i ?inv B" by blast
            let ?Inv = "multrow_mat nr i ?inv"
            from nZ valid have "i < nr" "?inv  0" by auto
            from multrow_mat_Unit[OF this]
            have Inv: "?Inv  ?U" .
            from m.Units_m_closed[OF P Inv] have PI: "P * ?Inv  ?U" unfolding ring_mat_def by simp
            from m.Units_closed[OF P] have P: "P  carrier_mat nr nr" unfolding ring_mat_def by simp
            show ?thesis unfolding id unfolding multrow_mat[OF A] multrow_mat[OF B]
              by (rule bexI[OF _ PI], intro conjI, 
                rule assoc_mult_mat[symmetric, OF P _ A], simp, 
                rule assoc_mult_mat[symmetric, OF P _ B], simp)
          next
            case True note O = this
            (* entry is 1: column j is eliminated; eliminate_entries is
               rewritten into eliminate_entries_rec so that
               Unit_prod_eliminate_entries supplies the unit Q *)
            let ?is = "filter (λ i'. i'  i) [0 ..< nr]" 
            let ?ais = "map (λ i'. (-A $$ (i',j), i')) ?is" 
            let ?E = "λ B. eliminate_entries (λ i. A $$ (i,j)) B i j"
            let ?EE = "λ B. eliminate_entries_rec B i ?ais"
            let ?A = "?E A"
            let ?B = "?E B"
            let ?AA = "?EE A"
            let ?BB = "?EE B"
            from O nZ valid res have "gauss_jordan_main ?A ?B (Suc i) (Suc j) = (A',B')"
              unfolding simp dim by simp
            note IH = IH(1)[OF O refl carrier_eliminate_entries(1)[OF A] carrier_eliminate_entries(1)[OF B] this]
            from IH obtain P where P: "P  ?U" and id: "A' = P * ?A" "B' = P * ?B" by blast
            have *: "j < dim_col A" "i < dim_row A" by (auto simp add: dim valid)
            have "P?U.  B nc. B  carrier_mat nr nc  ?EE B = P * B"
              by (rule Unit_prod_eliminate_entries, insert valid, auto)
            then obtain Q where Q: "Q  ?U" and QQ: " B nc. B  carrier_mat nr nc  ?EE B = Q * B" by auto
            {
              fix B :: "'a mat" and nc
              assume B: "B  carrier_mat nr nc" 
              with dim have "dim_row B = dim_row A" by auto
              from eliminate_entries_convert[OF * this]
              have "?E B = ?EE B" using dim by simp
              also have " = Q * B" using QQ[OF B] by simp
              finally have "?E B = Q * B" .
            } note QQ = this              
            have id3: "?A = Q * A" by (rule QQ[OF A])
            have id4: "?B = Q * B" by (rule QQ[OF B])
            from m.Units_closed[OF P] have Pc: "P  carrier_mat nr nr" unfolding ring_mat_def by simp
            from m.Units_closed[OF Q] have Qc: "Q  carrier_mat nr nr" unfolding ring_mat_def by simp
            from m.Units_m_closed[OF P Q] have PQ: "P * Q  ?U" unfolding ring_mat_def by simp
            show ?thesis unfolding id unfolding id3 id4 
              by (rule bexI[OF _ PQ], rule conjI, 
              rule assoc_mult_mat[symmetric, OF Pc Qc A],
              rule assoc_mult_mat[symmetric, OF Pc Qc B])
          qed
        next
          case True note Z = this
          note IH = IH(1-2)[OF Z]
          let ?is = "[ i' . i' <- [Suc i ..< nr],  A $$ (i',j)  0]"
          show ?thesis
          proof (cases ?is)
            case Nil 
            (* no candidate row: the call continues with column Suc j on the
               unchanged matrices, so the induction hypothesis applies directly *)
            from Z valid res have id: "gauss_jordan_main A B i (Suc j) = (A',B')" unfolding simp dim Nil by simp
            from IH(1)[OF Nil A B this] show ?thesis unfolding id .
          next
            case (Cons i' iis)
            (* candidate row i' found: rows i and i' are swapped *)
            from Z valid res have "gauss_jordan_main (swaprows i i' A) (swaprows i i' B) i j = (A',B')" 
              unfolding simp dim Cons by simp
            from IH(2)[OF Cons, unfolded swaprows_carrier, OF A B this]
            obtain P where P: "P  ?U" and
              id: "A' = P * swaprows i i' A" "B' = P * swaprows i i' B" by blast
            let ?Swap = "swaprows_mat nr i i'"
            from Cons have "i'  set ?is" by auto
            with valid have i': "i < nr" "i' < nr" by auto
            from swaprows_mat_Unit[OF this] have Swap: "?Swap  ?U" .
            from m.Units_m_closed[OF P Swap] have PI: "P * ?Swap  ?U" unfolding ring_mat_def by simp
            from m.Units_closed[OF P] have P: "P  carrier_mat nr nr" unfolding ring_mat_def by simp
            show ?thesis unfolding id swaprows_mat[OF A i'] swaprows_mat[OF B i']
              by (rule bexI[OF _ PI], rule conjI, 
              rule assoc_mult_mat[symmetric, OF P _ A], simp,
              rule assoc_mult_mat[symmetric, OF P _ B], simp)
          qed
        qed
      qed
    qed
  }
  (* instantiate the general statement with the start position (0,0) *)
  from this[of 0 0, folded gauss_jordan_def, OF res] show ?thesis .
qed

(* gauss_jordan_carrier: gauss_jordan keeps the dimensions of both matrices.
   Derived from gauss_jordan_transform: both results are products P * A and
   P * B with an nr x nr matrix P.  The extra ring_mat argument is
   instantiated with undefined here. *)
lemma gauss_jordan_carrier: assumes A: "(A :: 'a :: field mat)  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc'" 
  and res: "gauss_jordan A B = (A',B')"
  shows "A'  carrier_mat nr nc" "B'  carrier_mat nr nc'"
proof -
  from gauss_jordan_transform[OF A B res, of undefined]
  obtain P where P: "P  Units (ring_mat TYPE('a) nr undefined)"
    and id: "A' = P * A" "B' = P * B" by auto
  from P have P: "P  carrier_mat nr nr" unfolding Units_def ring_mat_def by auto
  show "A'  carrier_mat nr nc" "B'  carrier_mat nr nc'" unfolding id
    using P A B by auto
qed


(* pivot_fun A f nc: f assigns a pivot column to every row of A, where the
   value nc stands for "no pivot among the first nc columns".
   For every row i below dim_row A:
   - f i is at most nc;
   - if f i < nc, then A $$ (i, f i) = 1 and every other row has 0 in
     column f i;
   - all entries of row i in columns below f i are 0;
   - if Suc i is still a row, then f (Suc i) exceeds f i or equals nc. *)
definition pivot_fun :: "'a :: {zero,one} mat  (nat  nat)  nat  bool" where
  "pivot_fun A f nc  let nr = dim_row A in 
    ( i < nr. f i  nc  
      (f i < nc  A $$ (i, f i) = 1  ( i' < nr. i'  i  A $$ (i',f i) = 0))  
      ( j < f i. A $$ (i, j) = 0) 
      (Suc i < nr  f (Suc i) > f i  f (Suc i) = nc))"

(* pivot_funI: introduction rule for pivot_fun, with the components of the
   definition given as five separate premises and the row count fixed by
   dim_row A = nr. *)
lemma pivot_funI: assumes d: "dim_row A = nr"
  and *: " i. i < nr  f i  nc"
      " i j. i < nr  j < f i  A $$ (i,j) = 0"
      " i. i < nr  Suc i < nr  f (Suc i) > f i  f (Suc i) = nc"
      " i. i < nr  f i < nc  A $$ (i, f i) = 1"
      " i i'. i < nr  f i < nc  i' < nr  i'  i  A $$ (i',f i) = 0"
  shows "pivot_fun A f nc"
  unfolding pivot_fun_def Let_def d using * by blast

(* pivot_funD: destruction rule for pivot_fun.  It yields the five component
   facts in the same order as the premises of pivot_funI, so proofs refer to
   them by position (1)-(5). *)
lemma pivot_funD: assumes d: "dim_row A = nr"
  and p: "pivot_fun A f nc"
  shows " i. i < nr  f i  nc"
      " i j. i < nr  j < f i  A $$ (i,j) = 0"
      " i. i < nr  Suc i < nr  f (Suc i) > f i  f (Suc i) = nc"
      " i. i < nr  f i < nc  A $$ (i, f i) = 1"
      " i i'. i < nr  f i < nc  i' < nr  i'  i  A $$ (i',f i) = 0"
  using p unfolding pivot_fun_def Let_def d by blast+

(* pivot_fun_multrow: multiplying row i0 by a factor a preserves
   pivot_fun A f jj, provided f i0 = jj and jj does not exceed the number of
   columns nc. *)
lemma pivot_fun_multrow: assumes p: "pivot_fun A f jj"
  and d: "dim_row A = nr" "dim_col A = nc"
  and fi: "f i0 = jj"
  and jj: "jj  nc"
  shows "pivot_fun (multrow i0 a A) f jj"
proof -
  note p = pivot_funD[OF d(1) p]
  let ?A = "multrow i0 a A"
  have "dim_row ?A = nr" using d by simp
  thus ?thesis
  proof (rule pivot_funI)
    fix i
    assume i: "i < nr"
    note p = p[OF i]
    (* the two conditions that only mention f carry over unchanged *)
    show "f i  jj" by fact
    show "Suc i < nr  f i < f (Suc i)  f (Suc i) = jj" by fact
    {
      (* zeros in a pivot column outside the pivot row *)
      fix i'
      assume *: "f i < jj" "i' < nr" "i'  i" 
      from p(5)[OF this]
      show "?A $$ (i', f i) = 0"
        by (subst index_mat_multrow(1), insert * d jj, auto)
    }
    {
      (* the pivot entry itself stays 1 *)
      assume *: "f i < jj"
      from p(4)[OF this] have A: "A $$ (i, f i) = 1" by auto
      show "?A $$ (i, f i) = 1"
        by (subst index_mat_multrow(1), insert * d i A jj fi, auto)
    }
    {
      (* entries left of the pivot column stay 0 *)
      fix j
      assume j: "j < f i"
      from p(2)[OF j]
      show "?A $$ (i, j) = 0"
        by (subst index_mat_multrow(1), insert j d i p jj fi, auto)
    }
  qed
qed

(* pivot_fun_swaprows: swapping two rows l and k preserves pivot_fun A f jj,
   provided both rows exist, f l = jj and f k = jj, and jj does not exceed the
   number of columns nc. *)
lemma pivot_fun_swaprows: assumes p: "pivot_fun A f jj"
  and d: "dim_row A = nr" "dim_col A = nc"
  and flk: "f l = jj" "f k = jj"
  and nr: "l < nr" "k < nr"
  and jj: "jj  nc"
  shows "pivot_fun (swaprows l k A) f jj"
proof -
  note pivot = pivot_funD[OF d(1) p]
  let ?A = "swaprows l k A"
  have "dim_row ?A = nr" using d by simp
  thus ?thesis
  proof (rule pivot_funI)
    fix i
    assume i: "i < nr"
    note p = pivot[OF i]
    (* the two conditions that only mention f carry over unchanged *)
    show "f i  jj" by fact
    show "Suc i < nr  f i < f (Suc i)  f (Suc i) = jj" by fact
    {
      fix i'
      assume *: "f i < jj" "i' < nr" "i'  i" 
      (* a row with f i < jj is neither l nor k, because f l = f k = jj *)
      from *(1) flk have diff: "l  i" "k  i" by auto
      from p(5)[OF *] p(5)[OF *(1) nr(1) diff(1)] p(5)[OF *(1) nr(2) diff(2)]
      show "?A $$ (i', f i) = 0"  
        by (subst index_mat_swaprows(1), insert * d jj, auto)
    }
    {
      assume *: "f i < jj"
      from p(4)[OF this] have A: "A $$ (i, f i) = 1" by auto
      show "?A $$ (i, f i) = 1"
        by (subst index_mat_swaprows(1), insert * d i A jj flk, auto)
    }
    {
      fix j
      assume j: "j < f i"
      (* column j also lies left of the pivot columns of rows l and k *)
      with p(1) flk have le: "j < f l" "j < f k" by auto
      from p(2)[OF j] pivot(2)[OF nr(1) le(1)] pivot(2)[OF nr(2) le(2)]
      show "?A $$ (i, j) = 0" 
        by (subst index_mat_swaprows(1), insert j d i p jj, auto) 
    }
  qed
qed

(* pivot_fun_eliminate_entries: eliminate_entries vs A l j preserves
   pivot_fun A f jj, provided row l exists, f l = jj, and jj does not exceed
   the number of columns nc.
   The auxiliary fact "hint" states that for every row i and every column j
   below f i, column j is inside the matrix and row l has a 0 there. *)
lemma pivot_fun_eliminate_entries: assumes p: "pivot_fun A f jj"
  and d: "dim_row A = nr" "dim_col A = nc"
  and fl: "f l = jj"
  and nr: "l < nr"
  and jj: "jj  nc"
shows "pivot_fun (eliminate_entries vs A l j) f jj" 
proof -
  note pD = pivot_funD[OF d(1) p]
  {
    fix i j
    assume *: "i < nr" "j < f i"
    from pD(1)[OF this(1)] this(2) jj have j: "j < nc" by auto
    from pD nr fl * j have "A $$ (l, j) = 0" by (meson less_le_trans)
    note j this
  } note hint = this
  show ?thesis by (rule pivot_funI, insert fl nr jj pD, auto simp: eliminate_entries_gen_def d hint)
qed
    
(* row_echelon_form A: some function f is a pivot function for A with respect
   to all dim_col A columns. *)
definition row_echelon_form :: "'a :: {zero,one} mat  bool" where
  "row_echelon_form A   f. pivot_fun A f (dim_col A)"

(* pivot_fun_init: for column bound 0 the constant function 0 is a pivot
   function of any matrix.  This is the invariant at the start position. *)
lemma pivot_fun_init: "pivot_fun A (λ _. 0) 0"
  by (rule pivot_funI, auto)

(* gauss_jordan_main_row_echelon: the first component A' of the result of
   gauss_jordan_main A B i j is in row echelon form, provided this invariant
   holds at (i,j):
   - f is a pivot function of A for the first j columns;
   - rows before i have their pivot in a column before j;
   - rows from i on have f-value j, i.e. no pivot yet;
   - i and j do not exceed the dimensions nr and nc.
   Proof: computation induction; each case builds the pivot function for the
   recursive call and re-establishes the invariant.
   NOTE(review): the interpretation "m" and the abbreviation ?g do not seem to
   be used in the proof below; confirm before removing them. *)
lemma gauss_jordan_main_row_echelon: 
  assumes 
    "A  carrier_mat nr nc"
    "gauss_jordan_main A B i j = (A',B')"
    "pivot_fun A f j" 
    " i'. i' < i  f i' < j" " i'. i'  i  f i' = j"
    "i  nr" "j  nc"
  shows "row_echelon_form A'"
proof -
  fix b
  interpret m: ring "ring_mat TYPE('a) nr b" by (rule ring_mat)
  show ?thesis
    using assms
  proof (induct A B i j arbitrary: f rule: gauss_jordan_main.induct)
    case (1 A B i j f)
    (* 1(1-4): induction hypotheses; 1(5)-1(11): the assumptions in the order
       of the lemma statement *)
    note A = 1(5)
    hence dim: "dim_row A = nr" "dim_col A = nc" by auto
    note res = 1(6)
    note pivot = 1(7)
    note f = 1(8-9)
    note ij = 1(10-11)
    note IH = 1(1-4)[OF dim[symmetric]]
    note simp = gauss_jordan_main.simps[of A B i j] Let_def
    let ?g = "gauss_jordan_main A B i j"
    show ?case 
    proof (cases "i < nr  j < nc")
      case False note nij = this
      (* the function stops here, so A' = A; f is shown to be a pivot function
         for all nc columns *)
      with res have id: "A' = A" unfolding simp dim by auto
      have "pivot_fun A f nc"
      proof (cases "j  nc")
        case True
        with ij have j: "j = nc" by auto
        with pivot show "pivot_fun A f nc" by simp
      next
        case False
        (* columns remain but rows are exhausted: i = nr, so every row already
           has its pivot before column j *)
        hence j: "j < nc" by simp
        from False nij ij have i: "i = nr" by auto
        note f = f[unfolded i]
        note p = pivot_funD[OF dim(1) pivot]
        show ?thesis
        proof (rule pivot_funI[OF dim(1)])
          fix i
          assume i: "i < nr"
          note p = p[OF i]
          from p(1) j show "f i  nc" by simp
          from f(1)[OF i] have fij: "f i < j" .
          from p(4)[OF fij] show "A $$ (i, f i) = 1" .
          from p(5)[OF fij] show " i'. i' < nr  i'  i  A $$ (i', f i) = 0" .
          show " j. j < f i  A $$ (i, j) = 0" by (rule p(2))
          assume "Suc i < nr"
          with p(3)[OF this] f
          show "f i < f (Suc i)  f (Suc i) = nc" by auto
        qed          
      qed
      thus ?thesis using pivot unfolding id row_echelon_form_def dim by blast
    next
      case True note valid = this
      hence sij: "Suc i  nr" "Suc j  nc" by auto
      note IH = IH[OF valid refl]
      show ?thesis 
      proof (cases "A $$ (i,j) = 0")
        case False note nZ = this
        note IH = IH(3-4)[OF nZ]
        show ?thesis
        proof (cases "A $$ (i,j) = 1")
          case False note nO = this
          (* scaling step: position and f stay the same, the invariant is
             kept by pivot_fun_multrow *)
          let ?inv = "inverse (A $$ (i,j))"
          let ?A = "multrow i ?inv A"
          from nO nZ valid res have id: "gauss_jordan_main (multrow i ?inv A) (multrow i ?inv B) i j = (A', B')"
            unfolding simp dim by simp
          have "pivot_fun ?A f j"
            by (rule pivot_fun_multrow[OF pivot dim f(2) ij(2)], auto)
          note IH = IH(2)[OF nO refl, unfolded multrow_carrier, OF A id this f ij]
          show ?thesis unfolding id using IH .
        next
          case True note O = this
          (* elimination step: the call continues at (Suc i, Suc j) on the
             eliminated matrix E; the new pivot function ?f gives row i the
             pivot column j and all later rows the value Suc j *)
          let ?E = "λ B. eliminate_entries (λ i. A $$ (i,j)) B i j" 
          let ?A = "?E A"
          let ?B = "?E B"
          define E where "E = ?A"
          let ?f = "λ i'. if i' = i then j else if i' > i then Suc j else f i'"
          have pivot: "pivot_fun E f j" unfolding E_def          
            by (rule pivot_fun_eliminate_entries[OF pivot dim f(2)], insert valid, auto)
          {
            (* column j of E is 1 in row i and 0 in every other row *)
            fix i'
            assume i': "i' < nr"
            have "E $$ (i', j) = (if i' = i then 1 else 0)"
              unfolding E_def eliminate_entries_gen_def using dim O i' valid by auto
          } note Ej = this
          have E: "E  carrier_mat nr nc" unfolding E_def by (rule carrier_eliminate_entries[OF A])
          hence dimE: "dim_row E = nr" "dim_col E = nc" by auto
          note pivot = pivot_funD[OF dimE(1) pivot]
          have "pivot_fun E ?f (Suc j)"
          proof (rule pivot_funI[OF dimE(1)])
            fix ii
            assume ii: "ii < nr"
            note p = pivot[OF ii]
            show "?f ii  Suc j" using p(1) by simp
            {
              fix jj
              assume jj: "jj < ?f ii"
              show "E $$ (ii,jj) = 0"
              proof (cases "ii < i")
                case True
                with jj have "jj < f ii" by auto
                from p(2)[OF this] show ?thesis .
              next
                case False note ge = this
                with f have fiij: "f ii = j" by simp 
                show ?thesis
                proof (cases "i < ii")
                  case True
                  with jj have jj: "jj  j" by auto
                  show ?thesis
                  proof (cases "jj < j")
                    case True
                    with p(2)[of jj] fiij show ?thesis by auto
                  next
                    case False
                    with jj have jj: "jj = j" by auto
                    with Ej[OF ii] True show ?thesis by auto
                  qed
                next
                  case False
                  with ge have ii: "ii = i" by simp
                  with jj have jj: "jj < j" by simp
                  from p(2)[of jj] ii jj fiij show ?thesis by auto
                qed
              qed
            }
            {
              assume "Suc ii < nr"
              from p(3)[OF this] f
              show "?f (Suc ii) > ?f ii  ?f (Suc ii) = Suc j" by auto
            }
            {
              assume fii: "?f ii < Suc j"
              show "E $$ (ii, ?f ii) = 1"
              proof (cases "ii = i")
                case True
                with Ej[of i] valid show ?thesis by auto
              next
                case False
                with fii have ii: "ii < i" by (auto split: if_splits)
                from f(1)[OF this] have "f ii < j" by auto
                from p(4)[OF this] ii show ?thesis by simp
              qed
            }
            {
               fix i'
               assume *: "?f ii < Suc j" "i' < nr" "i'  ii"
               show "E $$ (i', ?f ii) = 0"
               proof (cases "ii = i")
                 case False
                 with *(1) have iii: "ii < i" by (auto split: if_splits)
                 from f(1)[OF this] have "f ii < j" by auto
                 from p(5)[OF this *(2-3)] show ?thesis using iii by simp
               next
                 case True
                 with *(2-3) Ej[of i'] show ?thesis by auto
               qed
            }
          qed 
          note IH = IH(1)[OF O refl, folded E_def, OF E _ this _ _ sij]     
          from O nZ valid res have "gauss_jordan_main E ?B (Suc i) (Suc j) = (A', B')"
            unfolding E_def simp dim by simp
          note IH = IH[OF this]
          show ?thesis  
          proof (rule IH)
            fix i'
            assume "i' < Suc i"
            thus "?f i' < Suc j" using f[of i'] by (cases "i' < i", auto)
          qed auto
        qed
      next
        case True note Z = this
        note IH = IH(1-2)[OF Z]
        let ?is = "[ i' . i' <- [Suc i ..< nr],  A $$ (i',j)  0]"
        show ?thesis
        proof (cases ?is)
          case Nil
          (* no candidate row: column j is 0 in all rows from i on, so the
             call moves to column Suc j with the same matrix *)
          {
            fix i'
            assume "i  i'" and "i' < nr"
            hence "i' = i  i'  {Suc i ..< nr}" by auto
            from this arg_cong[OF Nil, of set] Z have "A $$ (i',j) = 0" by auto
          } note zero = this
          let ?f = "λ i'. if i' < i then f i' else Suc j"
          note p = pivot_funD[OF dim(1) pivot]
          have "pivot_fun A ?f (Suc j)"
          proof (rule pivot_funI[OF dim(1)])
            fix ii
            assume ii: "ii < nr"
            note p = p[OF this]
            show "?f ii  Suc j" using p(1) by simp
            {
              fix jj
              assume jj: "jj < ?f ii"              
              show "A $$ (ii,jj) = 0"
              proof (cases "ii < i")
                case True
                with jj have "jj < f ii" by auto
                from p(2)[OF this] show ?thesis .
              next
                case False
                with jj have ii': "ii  i" and jjj: "jj  j" by auto
                from zero[OF ii' ii] have Az: "A $$ (ii,j) = 0" .
                show ?thesis
                proof (cases "jj < j")
                  case False
                  with jjj have "jj = j" by auto
                  with Az show ?thesis by simp
                next
                  case True
                  show ?thesis
                    by (rule p(2), insert True False f, auto)
                qed
              qed
            }
            {
              assume sii: "Suc ii < nr"
              show "?f ii < ?f (Suc ii)  ?f (Suc ii) = Suc j"
                using p(3)[OF sii] f by auto
            }
            {
              assume fii: "?f ii < Suc j"
              thus "A $$ (ii, ?f ii) = 1"
                using p(4) f by (cases "ii < i", auto)
              fix i'
              assume "i' < nr" "i'  ii"
              from p(5)[OF _ this] f fii
              show "A $$ (i', ?f ii) = 0" 
                by (cases "ii < i", auto)
            }
          qed
          note IH = IH(1)[OF Nil A _ this _ _ ij(1) sij(2)]
          from Z valid res have "gauss_jordan_main A B i (Suc j) = (A',B')" unfolding simp dim Nil by simp
          note IH = IH[OF this]
          show ?thesis  
            by (rule IH, insert f, force+)
        next
          case (Cons i' iis)
          (* swap step: rows i and i' both have f-value j, so
             pivot_fun_swaprows keeps the invariant for the same f *)
          from arg_cong[OF this, of set] have i': "i'  Suc i" "i' < nr" by auto
          from f[of i] f[of "i'"] i' have fij: "f i = j" "f i' = j" by auto 
          let ?A = "swaprows i i' A"
          let ?B = "swaprows i i' B"
          have "pivot_fun ?A f j"
            by (rule pivot_fun_swaprows[OF pivot dim fij], insert i' ij, auto)
          note IH = IH(2)[OF Cons, unfolded swaprows_carrier, OF A _ this f ij]
          from Z valid res have id: "gauss_jordan_main ?A ?B i j = (A', B')" unfolding simp dim Cons by simp
          note IH = IH[OF this]
          show ?thesis using IH .
        qed
      qed
    qed
  qed
qed

(* gauss_jordan_row_echelon: the first result component of gauss_jordan is in
   row echelon form.  Instance of gauss_jordan_main_row_echelon at position
   (0,0) with the initial pivot function from pivot_fun_init. *)
lemma gauss_jordan_row_echelon: 
  assumes A: "A  carrier_mat nr nc" 
  and res: "gauss_jordan A B = (A', B')"
  shows "row_echelon_form A'"
  by (rule gauss_jordan_main_row_echelon[OF A res[unfolded gauss_jordan_def] pivot_fun_init], auto)

(* pivot_bound: for a pivot function f with bound n, going j rows down from
   row i either reaches the bound (f (i + j) = n) or moves the pivot column
   at least j positions to the right of f i.
   Proof: induction on j using the step property pivot_funD(3). *)
lemma pivot_bound: assumes dim: "dim_row A = nr"
  and pivot: "pivot_fun A f n"
  shows "i + j < nr  f (i + j) = n  f (i + j)  j + f i"
proof (induct j)
  case (Suc j)
  hence IH: "f (i + j) = n  j + f i  f (i + j)" 
    and lt: "i + j < nr" "Suc (i + j) < nr" by auto
  note p = pivot_funD[OF dim pivot]
  from p(3)[OF lt] IH p(1)[OF lt(2)] show ?case by auto
qed simp

(* Generic pivot search, parameterised by the value "zero" that counts as an
   empty entry, the matrix A, and the dimensions nr and nc to scan. *)
context
  fixes zero :: 'a
  and A :: "'a mat"
  and nr nc :: nat
begin
(* pivot_positions_main_gen i j: scan from position (i,j).
   - outside the nr x nc area the result is [];
   - if A $$ (i,j) = zero, move one column to the right in the same row;
   - otherwise record (i,j) and continue at (Suc i, Suc j). *)
function pivot_positions_main_gen :: "nat  nat  (nat × nat) list" where
  "pivot_positions_main_gen i j = (
     if i < nr then
       if j < nc then 
         if A $$ (i,j) = zero then 
           pivot_positions_main_gen i (Suc j)
         else (i,j) # pivot_positions_main_gen (Suc i) (Suc j)
       else []
     else [])" by pat_completeness auto

(* Termination: measures on the remaining rows (Suc nr - i), then on the
   remaining columns (Suc nc - j). *)
termination by (relation "measures [(λ (i,j). Suc nr - i), (λ (i,j). Suc nc - j)]", auto)

(* The equation is unconditionally recursive, so it is not a default simp rule. *)
declare pivot_positions_main_gen.simps[simp del]
end

(* Correctness of the pivot search for matrices over a semiring_1, with the
   semiring's 0 as the "zero" parameter of pivot_positions_main_gen. *)
context
  fixes A :: "'a :: semiring_1 mat"
  and nr nc :: nat
begin

abbreviation "pivot_positions_main  pivot_positions_main_gen (0 :: 'a) A nr nc"

(* pivot_positions_main: let f be a pivot function of A for all nc columns.
   When started at (i,j) with j not beyond f i and i not beyond nr, the search
   returns exactly the pairs (i', f i') for the rows i' from i up to nr,
   except those whose f-value is nc.  Both the row and the column components
   of the result are pairwise distinct.
   Proof: induction along the recursion of pivot_positions_main_gen. *)
lemma pivot_positions_main: assumes A: "A  carrier_mat nr nc"
  and pivot: "pivot_fun A f nc"
  shows "j  f i  i  nr  
    set (pivot_positions_main i j) = {(i', f i') | i'. i  i'  i' < nr} - UNIV × {nc}
     distinct (map snd (pivot_positions_main i j))
     distinct (map fst (pivot_positions_main i j))"
proof (induct i j rule: pivot_positions_main_gen.induct[of nr nc A 0])
  case (1 i j)
  let ?a = "A $$ (i,j)"
  let ?pivot = "λ i j. pivot_positions_main i j"
  let ?set = "λ i. {(i',f i') | i'. i  i'  i' < nr}"
  let ?s = "?set i"
  let ?p = "?pivot i j"
  from A have dA: "dim_row A = nr" by simp
  note [simp] = pivot_positions_main_gen.simps[of 0 A nr nc i j]
  show ?case
  proof (cases "i < nr")
    case True note i = this
    note IH = 1(1-2)[OF True]
    have jfi: "j  f i" using 1(3) i by auto
    note pivotB = pivot_bound[OF dA pivot]
    note pivot' = pivot_funD[OF dA pivot]
    note pivot = pivot'[OF True]
    have id1: "[i ..< nr] = i # [Suc i ..< nr]" using i by (rule upt_conv_Cons)
    show ?thesis
    proof (cases "j < nc")
      case True note j = this
      note IH = IH(1-2)[OF True]
      show ?thesis
      proof (cases "?a = 0")
        case True note a = this
        (* zero entry: the search moves right; j cannot be the pivot column,
           since the pivot entry would be 1 *)
        from i j a have p: "?p = ?pivot i (Suc j)" by simp
        {
          assume "f i = j"
          with pivot(4) j have "?a = 1" by simp
          with a have False by simp
        }
        with jfi have "Suc j  f i  i  nr" by fastforce
        note IH = IH(1)[OF True this]
        thus ?thesis unfolding p .
      next
        case False note a = this
        (* non-zero entry: entries left of f i are 0, so j must equal f i *)
        from i j a have p: "?p = (i,j) # ?pivot (Suc i) (Suc j)" by simp
        from pivot(2)[of j] jfi a have jfi: "j = f i" by force
        from pivotB[of i "Suc 0"] jfi have "Suc j  f (Suc i)  nr  Suc i" 
          using Suc_le_eq j leI by auto
        note IH = IH(2)[OF False this]
        {
          (* no later row shares the pivot column of row i *)
          fix i'
          assume *: "f i = f i'" "Suc i  i'" "i' < nr" 
          hence "i + (i' - i) = i'" by auto
          from pivotB[of i "i' - i", unfolded this] * jfi j have False by auto
        } note distinct = this
        have id2: "?s = insert (i,j) (?set (Suc i))" using i jfi not_less_eq_eq
          by fastforce
        show ?thesis using IH j jfi i unfolding p id1 id2 by (auto intro: distinct)
      qed
    next
      case False note j = this
      (* columns exhausted: f i = nc, and by pivot_bound every later row has
         f-value nc as well, so the result is empty *)
      from pivot(1) j jfi have *: "f i = nc" "nc = j" by auto
      from i j have p: "?p = []" by simp
      from pivotB[of i "Suc 0"] * have "j  f (Suc i)  nr  Suc i" by auto
      {
        fix i'
        assume **: "i  i'" "i' < nr" 
        hence "i + (i' - i) = i'" by auto
        from pivotB[of i "i' - i", unfolded this] ** * have "nc  f i'" by auto
        with pivot'(1)[OF i' < nr] have "f i' = nc" by auto
      }
      thus ?thesis using IH unfolding p id1 by auto
    qed
  qed auto
qed
end

(* pivot_fun_zero_row_iff: for a pivot function f of an nr x nc matrix A and a
   row i below nr, f i = nc holds exactly when row i is the zero vector.
   Forward direction: all entries left of f i are 0.  Backward direction: a
   pivot inside the matrix would be an entry 1 in a zero row. *)
lemma pivot_fun_zero_row_iff: assumes pivot: "pivot_fun (A :: 'a :: semiring_1 mat) f nc"
  and A: "A  carrier_mat nr nc"
  and i: "i < nr"
  shows "f i = nc  row A i = 0v nc"
proof -
  from A have dim: "dim_row A = nr" by auto
  note pivot = pivot_funD[OF dim pivot i]
  {
    assume "f i = nc"
    from pivot(2)[unfolded this]
    have "row A i = 0v nc"
      by (intro eq_vecI, insert A, auto simp: row_def)
  }
  moreover
  {
    assume row: "row A i = 0v nc"
    assume "f i  nc"
    with pivot(1) have "f i < nc" by auto
    with pivot(4)[OF this] i A arg_cong[OF row, of "λ v. v $ f i"] have False by auto
  }
  ultimately show ?thesis by auto
qed

(* pivot_positions_gen zer A: run the generic pivot search over the whole
   matrix A (its own row and column counts), starting at position (0,0), with
   zer as the empty-entry value. *)
definition pivot_positions_gen :: "'a  'a mat  (nat × nat) list" where
  "pivot_positions_gen zer A  pivot_positions_main_gen zer A (dim_row A) (dim_col A) 0 0"

(* pivot_positions: the instance for semiring_1 matrices, using 0. *)
abbreviation pivot_positions :: "'a :: semiring_1 mat  (nat × nat) list" where
  "pivot_positions  pivot_positions_gen 0"

(* Alias so that proofs can unfold the abbreviation under its own name. *)
lemmas pivot_positions_def = pivot_positions_gen_def

(* pivot_positions: for an nr x nc matrix A with pivot function f,
   1. pivot_positions A lists exactly the pairs (i, f i) of the rows i below
      nr whose f-value is not nc;
   2. its row components are pairwise distinct;
   3. its column components are pairwise distinct;
   4. its length is the number of rows of A that are not the zero vector.
   Facts 1-3 are pivot_positions_main at (0,0); fact 4 follows from fact 2
   and pivot_fun_zero_row_iff. *)
lemma pivot_positions: assumes A: "A  carrier_mat nr nc"
  and pivot: "pivot_fun A f nc"
  shows 
    "set (pivot_positions A) = {(i, f i) | i. i < nr  f i  nc}"
    "distinct (map fst (pivot_positions A))"
    "distinct (map snd (pivot_positions A))"
    "length (pivot_positions A) = card { i. i < nr  row A i  0v nc}"
proof -
  from A have dim: "dim_row A = nr" by auto
  let ?pp = "pivot_positions A"
  show id: "set ?pp = {(i, f i) | i. i < nr  f i  nc}"
    and dist: "distinct (map fst ?pp)"
    and "distinct (map snd ?pp)"  
  using pivot_positions_main[OF A pivot, of 0 0] A
  unfolding pivot_positions_def by auto
  (* the length equals the number of distinct row indices in the list *)
  have "length ?pp = length (map fst ?pp)" by simp
  also have " = card (fst ` set ?pp)" using distinct_card[OF dist] by simp
  also have "fst ` set ?pp = { i. i < nr  f i  nc}" unfolding id by force
  also have " = { i. i < nr  row A i  0v nc}"
    using pivot_fun_zero_row_iff[OF pivot A] by auto
  finally show "length ?pp = card {i. i < nr  row A i  0v nc}" .
qed

(* Generic construction of candidate kernel vectors, parameterised by the
   negation function and the zero and one elements to use. *)
context 
  fixes uminus :: "'a  'a"
  and zero :: 'a
  and one :: 'a
begin
(* non_pivot_base_gen A pivots qj: a vector of length dim_col A.
   Entry i is
   - one, if i = qj;
   - uminus (A $$ (j,qj)), if the list pivots contains a pair (j,i), looked
     up through map_of on the swapped pairs;
   - zero otherwise.
   The local name nr is bound but not used in the body. *)
definition non_pivot_base_gen :: "'a mat  (nat × nat)list  nat  'a vec" where
  "non_pivot_base_gen A pivots  let nr = dim_row A; nc = dim_col A; 
     invers = map_of (map prod.swap pivots)
     in (λ qj. vec nc (λ i. 
     if i = qj then one else (case invers i of Some j => uminus (A $$ (j,qj)) | None  zero)))"

(* find_base_vectors_gen A: compute the pivot positions of A, filter the
   column indices [0 ..< dim_col A] against the pivot columns (set (map snd pp)),
   and build one non_pivot_base_gen vector per remaining column. *)
definition find_base_vectors_gen :: "'a mat  'a vec list" where
  "find_base_vectors_gen A  
    let 
      pp = pivot_positions_gen zero A;     
      cands = filter (λ j. j  set (map snd pp)) [0 ..< dim_col A]
    in map (non_pivot_base_gen A pp) cands"
end

(* Instances of the generic constructions for comm_ring_1 matrices, using the
   ring's own negation, 0 and 1. *)
abbreviation "non_pivot_base  non_pivot_base_gen uminus 0 (1 :: 'a :: comm_ring_1)"
abbreviation "find_base_vectors  find_base_vectors_gen uminus 0 (1 :: 'a :: comm_ring_1)"

(* Aliases so that proofs can unfold the abbreviations under their own names. *)
lemmas non_pivot_base_def = non_pivot_base_gen_def
lemmas find_base_vectors_def = find_base_vectors_gen_def

text ‹The soundness of @{const find_base_vectors} is proven in theory Matrix-Kern,
  where it is shown that @{const find_base_vectors} yields a basis of the kernel of $A$.›

(* find_base_vector A: like find_base_vectors, but returns only the
   non_pivot_base vector for the first candidate column.
   NOTE(review): hd is applied to the candidate list without an emptiness
   check; presumably callers only use this when a non-pivot column exists,
   confirm at the use sites. *)
definition find_base_vector :: "'a :: comm_ring_1 mat  'a vec" where
  "find_base_vector A  
    let 
      pp = pivot_positions A;     
      cands = filter (λ j. j  set (map snd pp)) [0 ..< dim_col A]
    in non_pivot_base A pp (hd cands)"

context
  fixes A :: "'a :: field mat" and nr nc :: nat and p :: "nat  nat"
  assumes ref: "row_echelon_form A"
  and A: "A  carrier_mat nr nc"
begin

(* Properties of the vector v = non_pivot_base A pp qj, where pp = pivot_positions A and
   qj < nc is a column index that does not occur as second component in pp:
   v is in carrier_vec nc, v $ qj = 1, A *v v is the zero vector of dimension nr, and
   v $ qj' = 0 for every other index qj' < nc that is not a second component of pp.
   The proof uses the facts A and ref, which are not declared in this lemma
   (presumably they come from the enclosing context). *)
lemma non_pivot_base:
  defines pp: "pp  pivot_positions A"
  assumes qj: "qj < nc" "qj  snd ` set pp" 
  shows "non_pivot_base A pp qj  carrier_vec nc"
    "non_pivot_base A pp qj $ qj = 1"
    "A *v non_pivot_base A pp qj = 0v nr"
    " qj'. qj' < nc  qj'  snd ` set pp  qj  qj'  non_pivot_base A pp qj $ qj' = 0"
proof -
  from A have dim: "dim_row A = nr" "dim_col A = nc" by auto
  from ref[unfolded row_echelon_form_def] obtain p 
  where pivot: "pivot_fun A p nc" using dim by auto
  note pivot' = pivot_funD[OF dim(1) pivot]
  note pp = pivot_positions[OF A pivot, folded pp]
  let ?p = "λ i. i < nr  p i = nc  i = nr"
  let ?spp = "map prod.swap pp"
  let ?map = "map_of ?spp"
  (* I describes the entries of the vector at indices different from qj, see fact d *)
  define I where "I = (λ i. case map_of (map prod.swap pp) i of Some j  - A $$ (j,qj) | None  0)"
  have d: "non_pivot_base A pp qj = vec nc (λ i. if i = qj then 1 else I i)"
    unfolding non_pivot_base_def Let_def dim I_def ..
  from pp have dist: "distinct (map fst ?spp)" 
    unfolding map_map o_def prod.swap_def by auto
  let ?r = "set (map snd pp)"
  have r: "?r = p ` {0 ..< nr} - {nc}" unfolding set_map pp by force
  let ?l = "set (map fst pp)"
  from qj have qj': "qj  p ` {0 ..< nr}" using r by auto
  let ?v = "non_pivot_base A pp qj"
  let ?P = "p ` {0 ..< nr}"
  have dimv: "dim_vec ?v = nc" unfolding d by simp
  thus "?v  carrier_vec nc" unfolding carrier_vec_def by auto
  show vqj: "?v $ qj = 1" unfolding d using qj by auto
  (* fourth claim: the entry at any other index that is no second component of pp is 0 *)
  { 
    fix qj'
    assume *: "qj' < nc" "qj  qj'" "qj'  snd ` set pp"
    hence "?map qj' = None" unfolding map_of_eq_None_iff by force
    hence "I qj' = 0" unfolding I_def by simp
    with * show "non_pivot_base A pp qj $ qj' = 0" 
      unfolding d by simp
  }    
  (* third claim, row by row: every row of A has scalar product 0 with ?v *)
  {
    fix i
    assume i: "i < nr"
    let ?I = "{j. ?map j = Some i}"
    have "row A i  ?v = 0" 
    proof -
      have id: "({0..<nc}  ?P)  ({0..<nc} - ?P) = {0..<nc}" by auto
      let ?e = "λ j. row A i $ j * ?v $ j"
      let ?e' = "λ j. (if ?map j = Some i then - A $$ (i, qj) else 0)"
      (* for an index j < nc in the image ?P of p, the summand ?e j equals ?e' j *)
      {
        fix j
        assume j: "j < nc" "j  ?P"
        then obtain ii where ii: "ii < nr" and jpi: "j = p ii" and pii: "p ii < nc" by auto
        hence mem: "(ii,j)  set pp" and "(j,ii)  set ?spp" by (auto simp: pp)        
        from map_of_is_SomeI[OF dist this(2)] 
        have map: "?map j = Some ii" by auto
        from mem j qj have jqj: "j  qj" by force
        note p = pivot'(4-5)[OF ii pii]
        define start where "start = ?e j"
        have "start = A $$ (i,j) * ?v $ j" using j i A by (auto simp: start_def)
        also have "A $$ (i,j) = A $$ (i, p ii)" unfolding jpi ..
        also have " = (if i = ii then 1 else 0)" using p(1) p(2)[OF i] by auto
        also have " * ?v $ j = (if i = ii then ?v $ j else 0)" by simp
        also have "?v $ j = I j" unfolding d 
          using j jqj A by auto
        also have "I j = - A $$ (ii, qj)" unfolding I_def map by simp
        finally have "?e j = ?e' j" 
          unfolding start_def map by auto
      } note piv = this
      (* the sum over all column indices is split into the part inside ?P and the rest;
         the rest only contributes the summand at qj *)
      have "row A i  ?v = ( j = 0..<nc. ?e j)" unfolding row_def scalar_prod_def dimv ..
      also have " = sum ?e ({0..<nc}  ?P) + sum ?e ({0..<nc} - ?P)"
        by (subst sum.union_disjoint[symmetric], auto simp: id)
      also have "sum ?e ({0..<nc} - ?P) = ?e qj + sum ?e ({0 ..<nc} - ?P - {qj})"
        by (rule sum.remove, insert qj qj', auto)
      also have "?e qj = row A i $ qj" unfolding vqj by simp
      also have "row A i $ qj = A $$ (i, qj)" using i A qj by auto
      also have "sum ?e ({0 ..<nc} - ?P - {qj}) = 0"
      proof (rule sum.neutral, intro ballI)
        fix j
        assume "j  {0 ..<nc} - ?P - {qj}"
        hence j: "j < nc" "j  ?P" "j  qj" "j  ?r" unfolding r by auto
        hence id: "map_of ?spp j = None" unfolding map_of_eq_None_iff by force
        have "?v $ j = I j" unfolding d using j by simp
        also have " = 0" unfolding I_def id by simp 
        finally show "row A i $ j * ?v $ j = 0" by simp
      qed
      also have "A $$ (i, qj) + 0 = A $$ (i, qj)" by simp
      also have "sum ?e ({0..<nc}  ?P) = sum ?e' ({0..<nc}  ?P)"
        by (rule sum.cong, insert piv, auto)
      also have "{0..<nc}  ?P = {0..<nc}  ?I  ?P  ({0..<nc} - ?I)  ?P" by auto
      also have "sum ?e' ({0..<nc}  ?I  ?P  ({0..<nc} - ?I)  ?P)
        = sum ?e' ({0..<nc}  ?I  ?P) + sum ?e' (({0..<nc} - ?I)  ?P)"
        by (rule sum.union_disjoint, auto)
      also have "sum ?e' (({0..<nc} - ?I)  ?P) = 0"
        by (rule sum.neutral, auto)
      also have "sum ?e' ({0..<nc}  ?I  ?P) = 
        sum (λ _. - A $$ (i, qj)) ({0..<nc}  ?I  ?P)"
        by (rule sum.cong, auto)
      also have " + 0 = " by simp
      also have "sum (λ _. - A $$ (i, qj)) ({0..<nc}  ?I  ?P) + A $$ (i, qj) = 0" 
      (* case distinction on whether row i occurs as first component in pp *)
      proof (cases "i  ?l")
        case False
        with pp(1) i have "p i = nc" by force
        from pivot'(2)[OF i, unfolded this, OF qj(1)] have z: "A $$ (i, qj) = 0" .
        show ?thesis 
          by (subst sum.neutral, auto simp: z)
      next
        case True
        then obtain j where mem: "(i,j)  set pp" and id: "(j,i)  set ?spp" by auto
        from map_of_is_SomeI[OF dist this(2)] have map: "?map j = Some i" .
        from pivot'(1)[OF i] have pi: "p i  nc" .
        with mem[unfolded pp] have j: "j = p i" "j < nc" by auto
        {
          fix j'
          assume "j'  ?I"
          hence "?map j' = Some i" by auto
          from map_of_SomeD[OF this] have "(i, j')  set pp" by auto
          with mem pp(2) have "j' = j" using map_of_is_SomeI by fastforce
        }
        with map have II: "?I = {j}" by blast
        have II: "{0..<nc}  ?I  ?P = {j}" unfolding II using mem[unfolded pp] i j by auto
        show ?thesis unfolding II by simp
      qed
      finally show "row A i  ?v = 0" .
    qed
  } note main = this
  show "A *v ?v = 0v nr"  
    by (rule eq_vecI, auto simp: dim main)
qed

(* If the set of second components of pivot_positions A is different from {0 ..< nc},
   then find_base_vector A is a vector in carrier_vec nc which differs from the zero
   vector and satisfies A *v find_base_vector A = 0v nr.
   The proof takes the first index c of [0 ..< nc] that is no second component of
   pivot_positions A, shows find_base_vector A = non_pivot_base A (pivot_positions A) c,
   and then applies lemma non_pivot_base. *)
lemma find_base_vector: assumes "snd ` set (pivot_positions A)  {0 ..< nc}"
  shows 
    "find_base_vector A  carrier_vec nc"
    "find_base_vector A  0v nc"
    "A *v find_base_vector A = 0v nr"
proof -
  (* cands: all column indices that are no second component of a pivot position *)
  define cands where "cands = filter (λ j. j  snd ` set (pivot_positions A)) [0 ..< nc]"
  from A have dim: "dim_row A = nr" "dim_col A = nc" by auto
  from ref[unfolded row_echelon_form_def] obtain p 
  where pivot: "pivot_fun A p nc" using dim by auto
  note piv = pivot_funD[OF dim(1) pivot]
  have "set cands  {}" using assms piv unfolding cands_def  pivot_positions[OF A pivot]
    by (auto simp: le_neq_implies_less)
  then obtain c cs where cands: "cands = c # cs" by (cases cands, auto)
  hence res: "find_base_vector A = non_pivot_base A (pivot_positions A) c"
    unfolding find_base_vector_def Let_def cands_def dim by auto
  from cands have "c  set cands" by auto
  hence c: "c < nc" "c  snd ` set (pivot_positions A)"
    unfolding cands_def by auto
  from non_pivot_base[OF this, folded res] c show
    "find_base_vector A  carrier_vec nc"
    "find_base_vector A  0v nc"
    "A *v find_base_vector A = 0v nr"
  by auto
qed
end

(* A square n x n matrix satisfying row_echelon_form is either the identity matrix 1m n,
   or n > 0 and its last row (index n - 1) is the zero vector.
   Proof idea, with pivot function f: if some i < n has f i different from i, then
   f (n - 1) = n follows and the last row is zero; otherwise f j = j for all j < n and
   the entries of A are those of the identity matrix. *)
lemma row_echelon_form_imp_1_or_0_row: assumes A: "A  carrier_mat n n"
  and row: "row_echelon_form A"
  shows "A = 1m n  (n > 0  row A (n - 1) = 0v n)"
proof -
  from A have dim: "dim_row A = n" "dim_col A = n" by auto
  from row[unfolded row_echelon_form_def] A
  obtain f where pivot: "pivot_fun A f n" by auto
  note p = pivot_funD[OF dim(1) this]
  show ?thesis
  proof (cases " i < n. f i  i")
    case True
    (* some pivot is off the diagonal: derive f (n - 1) = n, i.e. a zero last row *)
    then obtain i where i: "i < n" and fi: "f i  i" by auto
    note pb = pivot_bound[OF dim(1) pivot]
    from pb[of 0 i] i have "f i = n  i  f i" by auto
    with fi have fi: "f i = n  i < f i" by auto
    from i have n: "n - 1 = i + (n - i - 1)" by auto
    from pb[of i "n - i - 1", folded n] fi i p(1)[of "n - 1"] 
    have fn: "f (n - 1) = n" by auto
    from i have n0: "n > 0" and n1: "n - 1 < n" by auto
    from p(2)[OF n1, unfolded fn] have zero: " j. j < n  A $$ (n - 1, j) = 0" by auto
    show ?thesis
      by (rule disjI2[OF conjI[OF n0]], rule eq_vecI, insert zero A, auto)
  next
    case False
    (* all pivots are on the diagonal: A coincides entrywise with the identity matrix *)
    {
      fix j
      assume j: "j < n"
      with False have id: "f j = j" by auto
      note pj = p[OF j, unfolded id]
      from pj(5)[OF j] pj(4)[OF j] 
      have " i. i < n  A $$ (i,j) = (if i = j then 1 else 0)" by auto
    } note id = this
    show ?thesis
      by (rule disjI1, rule eq_matI, subst id, insert A, auto)
  qed
qed

(* Setting for the two lemmas below: A is a square n x n matrix over a field which
   satisfies row_echelon_form and differs from the identity matrix 1m n.
   The formerly fixed variable p was dropped: it occurred in no assumption or statement
   and was shadowed by the locally obtained pivot function in the proof. *)
context
  fixes A :: "'a :: field mat" and n :: nat
  assumes ref: "row_echelon_form A"
  and A: "A  carrier_mat n n"
  and 1: "A  1m n"
begin

(* The second components of pivot_positions A do not form the whole set {0 ..< n}.
   Proof by contradiction: the last row of A is zero (row_echelon_form_imp_1_or_0_row),
   hence the pivot function yields n for row n - 1 and this row is no first component
   of a pivot position. So there are at most n - 1 first components, whereas the number
   of first components equals the number of second components, which would be n. *)
lemma find_base_vector_not_1_pivot_positions: "snd ` set (pivot_positions A)  {0 ..< n}"
proof 
  let ?pp = "pivot_positions A"
  assume id: "snd ` set ?pp = {0 ..< n}"
  from A have dim: "dim_row A = n" "dim_col A = n" by auto
  let ?n = "n - 1"
  from row_echelon_form_imp_1_or_0_row[OF A ref] 1
  have *: "0 < n" and row: "row A ?n = 0v n" by auto
  from ref[unfolded row_echelon_form_def] obtain p 
    where pivot: "pivot_fun A p n" using dim by auto
  note pp = pivot_positions[OF A pivot]
  note piv = pivot_funD[OF dim(1) pivot]
  from * have n: "?n < n" by auto
  (* a pivot inside the last row would be an entry 1 in a zero row *)
  {
    assume "p ?n < n"
    with piv(4)[OF n this] row n A have False
      by (metis dim index_row(1) index_zero_vec(1) zero_neq_one)
  }
  with piv(1)[OF n] have pn: "p ?n = n" by fastforce
  hence "?n  fst ` set ?pp" unfolding pp by auto
  hence "fst ` set ?pp  {0 ..< n} - {?n}" unfolding pp by force
  also have "  {0 ..< n - 1}" by auto
  finally have "card (fst ` set ?pp)  card {0 ..< n - 1}" using card_mono by blast
  also have " = n - 1" by auto
  also have "card (fst ` set ?pp) = card (snd ` set ?pp)"
    unfolding set_map[symmetric] distinct_card[OF pp(2)] distinct_card[OF pp(3)] by simp
  also have " = n" unfolding id by simp
  finally show False using n by simp
qed
  
(* Instance of find_base_vector for the setting of this context: find_base_vector A is
   in carrier_vec n, differs from the zero vector and is mapped to 0v n by A. *)
lemma find_base_vector_not_1: 
    "find_base_vector A  carrier_vec n"
    "find_base_vector A  0v n"
    "A *v find_base_vector A = 0v n"
  using find_base_vector[OF ref A find_base_vector_not_1_pivot_positions] .
end

(* Soundness of gauss_jordan: if gauss_jordan A B = (C,D) for A in carrier_mat nr nc and
   B in carrier_mat nr nc2, then
   - A *v x = 0v nr holds iff C *v x = 0v nr holds (for x in carrier_vec nc),
   - A * X = B holds iff C * X = D holds (for X in carrier_mat nc nc2),
   - C and D have the same dimensions as A and B, respectively.
   Proof: gauss_jordan_transform yields matrices P and Q with Q * P = 1m nr, C = P * A
   and D = P * B; multiplying with P resp. Q converts between the two systems. *)
lemma gauss_jordan: assumes A: "A  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc2"
  and gauss: "gauss_jordan A B = (C,D)"
  shows "x  carrier_vec nc  (A *v x = 0v nr) = (C *v x = 0v nr)" (is "_  ?l = ?r")
    "X  carrier_mat nc nc2   (A * X = B) = (C * X = D)" (is " _  ?l2 = ?r2")
    "C  carrier_mat nr nc"
    "D  carrier_mat nr nc2"
proof -
  from gauss_jordan_transform[OF A B gauss, unfolded ring_mat_def Units_def, simplified]
  obtain P Q where P: "P  carrier_mat nr nr" and Q: "Q  carrier_mat nr nr"
    and inv: "Q * P = 1m nr" 
    and CPA: "C = P * A" 
    and DPB: "D = P * B" by auto
  from CPA P A show C: "C  carrier_mat nr nc" by auto
  from DPB P B show D: "D  carrier_mat nr nc2" by auto
  (* express A and B via Q, so that both directions can be shown by multiplication *)
  have "A = 1m nr * A" using A by simp
  also have " = Q * C" unfolding inv[symmetric] CPA using Q P A by simp
  finally have AQC: "A = Q * C" .
  have "B = 1m nr * B" using B by simp
  also have " = Q * D" unfolding inv[symmetric] DPB using Q P B by simp
  finally have BQD: "B = Q * D" .
  {
    assume x: "x  carrier_vec nc"
    {
      assume ?l
      from arg_cong[OF this, of "λ v. P *v v"] P A x have ?r unfolding CPA by auto
    }
    moreover
    {
      assume ?r
      from arg_cong[OF this, of "λ v. Q *v v"] Q C x have ?l unfolding AQC by auto
    }
    ultimately show "?l = ?r" by auto
  }
  {
    assume X: "X  carrier_mat nc nc2"
    {
      assume ?l2
      from arg_cong[OF this, of "λ X. P * X"] P A X have ?r2 unfolding CPA DPB by simp
    }
    moreover
    {
      assume ?r2
      from arg_cong[OF this, of "λ X. Q * X"] Q C X have ?l2 unfolding AQC BQD by simp
    }
    ultimately show "?l2 = ?r2" by auto
  }
qed

(* Gauss-Jordan elimination for a single matrix: gauss_jordan is run on A together with
   the matrix 0m (dim_row A) 0 as second argument, and only the first component of the
   result is returned. *)
definition gauss_jordan_single :: "'a :: field mat  'a mat" where
  "gauss_jordan_single A = fst (gauss_jordan A (0m (dim_row A) 0))"

(* Properties of C = gauss_jordan_single A for A in carrier_mat nr nc:
   A *v x = 0v nr iff C *v x = 0v nr (for x in carrier_vec nc), C is in carrier_mat nr nc,
   C satisfies row_echelon_form, and C = P * A for some P, Q in carrier_mat nr nr with
   P * Q = 1m nr and Q * P = 1m nr.
   All parts are obtained from the corresponding facts about gauss_jordan. *)
lemma gauss_jordan_single: assumes A: "A  carrier_mat nr nc"
  and gauss: "gauss_jordan_single A = C"
  shows "x  carrier_vec nc  (A *v x = 0v nr) = (C *v x = 0v nr)" 
    "C  carrier_mat nr nc"
    "row_echelon_form C"
    " P Q. C = P * A  P  carrier_mat nr nr  Q  carrier_mat nr nr  P * Q = 1m nr  Q * P = 1m nr" (is "?ex")
proof -
  from A gauss[unfolded gauss_jordan_single_def] obtain D where gauss: "gauss_jordan A (0m nr 0) = (C,D)"
    by (cases "gauss_jordan A (0m nr 0)", auto)
  from gauss_jordan[OF A zero_carrier_mat gauss] gauss_jordan_row_echelon[OF A gauss]
    gauss_jordan_transform[OF A zero_carrier_mat gauss, of "()"]
  show "x  carrier_vec nc  (A *v x = 0v nr) = (C *v x = 0v nr)" 
    "C  carrier_mat nr nc" "row_echelon_form C" ?ex unfolding Units_def ring_mat_def by auto
qed



(* If gauss_jordan A B returns the identity matrix 1m n as first component, then A is a
   unit of the matrix ring ring_mat TYPE('a) n b. If in addition B = 1m n, then the
   second component B' satisfies A * B' = 1m n and B' * A = 1m n.
   Proof: gauss_jordan_transform yields a unit P with P * A = 1m n and B' = P * B. *)
lemma gauss_jordan_inverse_one_direction: 
  assumes A: "A  carrier_mat n n" and B: "B  carrier_mat n nc"
  and res: "gauss_jordan A B = (1m n, B')"
  shows "A  Units (ring_mat TYPE('a :: field) n b)"
  "B = 1m n  A * B' = 1m n  B' * A = 1m n"
proof -
  let ?R = "ring_mat TYPE('a) n b"
  let ?U = "Units ?R"
  interpret m: ring ?R by (rule ring_mat)
  from gauss_jordan_transform[OF A B res, of b]
  obtain P where P: "P  ?U" and id: "P * A = 1m n" and B': "B' = P * B" by auto
  from P have Pc: "P  carrier_mat n n" unfolding Units_def ring_mat_def by auto
  from m.Units_one_side_I(1)[of A P] A P id show Au: "A  ?U" unfolding ring_mat_def by auto
  (* second claim: with B = 1m n the result B' coincides with P *)
  assume B: "B = 1m n" 
  from B'[unfolded this] Pc have B': "B' = P" by auto
  show "A * B' = 1m n  B' * A = 1m n" unfolding B' 
    using m.Units_inv_comm[OF _ P Au] id by (auto simp: ring_mat_def)
qed

(* If A is a unit of ring_mat TYPE('a) n b, then the first component of gauss_jordan A B
   is the identity matrix 1m n.
   Proof: the first component A' equals P * A for a unit P and hence is a unit itself,
   with some right inverse IA'. By row_echelon_form_imp_1_or_0_row, A' is 1m n or has a
   zero last row; in the latter case the entry (n - 1, n - 1) of A' * IA' = 1m n would
   be both 1 and 0. *)
lemma gauss_jordan_inverse_other_direction: 
  assumes AU: "A  Units (ring_mat TYPE('a :: field) n b)" and B: "B  carrier_mat n nc"
  shows "fst (gauss_jordan A B) = 1m n"
proof -
  let ?R = "ring_mat TYPE('a) n b"
  let ?U = "Units ?R"
  interpret m: ring ?R by (rule ring_mat)
  from AU have A: "A  carrier_mat n n" unfolding Units_def ring_mat_def by auto
  obtain A' B' where res: "gauss_jordan A B = (A',B')" by force
  from gauss_jordan_transform[OF A B res, of b]
  obtain P where P: "P  ?U" and id: "A' = P * A" by auto
  from m.Units_m_closed[OF P AU]  have A': "A'  ?U" unfolding id ring_mat_def by auto
  hence A'c: "A'  carrier_mat n n" unfolding Units_def ring_mat_def by auto
  from A'[unfolded Units_def ring_mat_def] obtain IA' where IA': "IA'  carrier_mat n n"
    and IA: "A' * IA' = 1m n" by auto
  from row_echelon_form_imp_1_or_0_row[OF gauss_jordan_carrier(1)[OF A B res] gauss_jordan_row_echelon[OF A res]] 
  have choice: "A' = 1m n  0 < n  row A' (n - 1) = 0v n" .
  hence "A' = 1m n"
  proof 
    (* only the second alternative of choice needs an argument: it is contradictory *)
    let ?n = "n - 1"
    assume "0 < n  row A' ?n = 0v n" 
    hence n: "?n < n" and row: "row A' ?n =  0v n" by auto
    have "1 = 1m n $$ (?n,?n)" using n by simp
    also have "1m n = A' * IA'" unfolding IA ..
    also have "(A' * IA') $$ (?n, ?n) = row A' ?n  col IA' ?n"
      using n IA' A'c by simp
    also have "row A' ?n = 0v n" unfolding row ..
    also have "0v n  col IA' ?n = 0" using IA' n by simp
    finally have "1 = (0 :: 'a)" by simp
    thus ?thesis by simp
  qed 
  with res show ?thesis by simp
qed

(* Computing an inverse via gauss_jordan: if gauss_jordan A (1m n) = (1m n, B') for a
   square matrix A in carrier_mat n n, then B' is in carrier_mat n n and satisfies
   A * B' = 1m n and B' * A = 1m n. *)
lemma gauss_jordan_compute_inverse:
  assumes A: "A  carrier_mat n n"
  and res: "gauss_jordan A (1m n) = (1m n, B')"
  shows "A * B' = 1m n" "B' * A = 1m n" "B'  carrier_mat n n"
proof -
  from gauss_jordan_inverse_one_direction(2)[OF A _ res refl, of n]
  show "A * B' = 1m n" "B' * A = 1m n" by auto
  from gauss_jordan_carrier(2)[OF A _ res, of n] show "B'  carrier_mat n n" by auto
qed

(* Invertibility test: for A in carrier_mat n n and B in carrier_mat n nc, A is a unit of
   ring_mat TYPE('a) n b iff the first component of gauss_jordan A B is the identity
   matrix. Combines gauss_jordan_inverse_other_direction (left to right) and
   gauss_jordan_inverse_one_direction (right to left). *)
lemma gauss_jordan_check_invertable: assumes A: "A  carrier_mat n n" and B: "B  carrier_mat n nc"
  shows "(A  Units (ring_mat TYPE('a :: field) n b))  fst (gauss_jordan A B) = 1m n"
  (is "?l = ?r")
proof 
  assume ?l
  show ?r
    by (rule gauss_jordan_inverse_other_direction[OF ?l B])
next
  let ?g = "gauss_jordan A B"
  assume ?r
  then obtain B' where "?g = (1m n, B')" by (cases ?g, auto)
  from gauss_jordan_inverse_one_direction(1)[OF A B this]
  show ?l .
qed

(* mat_inverse A: if A is square (dim_row A = dim_col A), gauss_jordan is run on A and
   the identity matrix of that dimension; if the first result component is the identity
   matrix, the second component C is returned as Some C, otherwise None.
   For non-square A the result is None. *)
definition mat_inverse :: "'a :: field mat  'a mat option" where 
  "mat_inverse A = (if dim_row A = dim_col A then
    let one = 1m (dim_row A) in
    (case gauss_jordan A one of
      (B, C)  if B = one then Some C else None) else None)"

(* Correctness of mat_inverse for A in carrier_mat n n:
   - if the result is None, then A is not a unit of ring_mat TYPE('a) n b;
   - if the result is Some B, then B is in carrier_mat n n with A * B = 1m n and
     B * A = 1m n. *)
lemma mat_inverse: assumes A: "A  carrier_mat n n"
  shows "mat_inverse A = None  A  Units (ring_mat TYPE('a :: field) n b)"
    "mat_inverse A = Some B  A * B = 1m n  B * A = 1m n  B  carrier_mat n n"
proof -
  let ?one = "1m n"
  obtain BB C where res: "gauss_jordan A ?one = (BB,C)" by force
  {
    (* None case: the first component differs from the identity, so A is no unit *)
    assume "mat_inverse A = None"
    with res have "BB  ?one" unfolding mat_inverse_def using A by auto
    thus "A  Units (ring_mat TYPE('a :: field) n b)"
      using gauss_jordan_check_invertable[OF A, of ?one n] res by force
  }
  {
    (* Some case: the first component is the identity and B is the second component *)
    assume "mat_inverse A = Some B"
    with res A have "BB = ?one" "C = B" unfolding mat_inverse_def
      by (auto split: if_splits option.splits)
    from gauss_jordan_compute_inverse[OF A res[unfolded this]]
    show "A * B = 1m n  B * A = 1m n  B  carrier_mat n n" by auto
  }
qed
end

Theory Gauss_Jordan_IArray_Impl

(*  
    Author:      Sebastiaan Joosten
                 René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Code Generation for Basic Matrix Operations›

text ‹In this theory we provide efficient implementations
  for the elementary row-transformations. These are necessary since the default
  implementations would construct a whole new matrix in every step.›

theory Gauss_Jordan_IArray_Impl
imports 
  Polynomial_Interpolation.Missing_Unsorted
  Matrix_IArray_Impl
  Gauss_Jordan_Elimination
begin

(* Row swap on the representation (nr,nc,A) of a matrix: if both i and j are below nr,
   the rows stored at positions i and j of the row array are exchanged by a list update;
   otherwise the representation is returned unchanged. *)
lift_definition mat_swaprows_impl :: "nat  nat  'a mat_impl  'a mat_impl" is
  "λ i j (nr,nc,A). if i < nr  j < nr then 
  let Ai = IArray.sub A i; 
      Aj = IArray.sub A j;
      Arows = IArray.list_of A;
      A' = IArray.IArray (Arows [i := Aj, j := Ai])  
   in (nr,nc,A')
     else (nr,nc,A)" 
  by (auto split: if_splits)

(* Code equation for mat_swaprows on mat_impl A: if both indices are below
   dim_row_impl A, the result is computed by mat_swaprows_impl; otherwise Code.abort is
   invoked with an "index out of bounds" message and the original term as continuation. *)
lemma [code]: "mat_swaprows k l (mat_impl A) = (let nr = dim_row_impl A in
  if l < nr  k < nr then 
  mat_impl (mat_swaprows_impl k l A) else Code.abort (STR ''index out of bounds in mat_swaprows'') 
  (λ _. mat_swaprows k l (mat_impl A)))" (is "?l = ?r")
proof (cases "l < dim_row_impl A  k < dim_row_impl A")
  case True
  hence id: "?r = mat_impl (mat_swaprows_impl k l A)" by simp
  show ?thesis unfolding id unfolding mat_swaprows_def
  proof (rule eq_matI, goal_cases)
    (* entrywise equality is shown on the level of the representation via transfer *)
    case (1 i j)
    thus ?case using True
    proof (transfer, goal_cases)
      case (1 i k l A j)
      obtain nr nc rows where A: "A = (nr,nc,rows)" by (cases A, auto)
      from 1[unfolded A]
      have nr: "length (IArray.list_of rows) = nr"
        and nc: "IArray.all (λr. length (IArray.list_of r) = nc) rows"
        and ij: "i < nr" "j < nc" and ij': "(i < nr  j < nc) = True" 
        and l: "l < nr" "k < nr" by auto
      show ?case unfolding A prod.simps fst_conv o_def snd_conv Let_def mk_mat_def ij' if_True
        using ij nr nc l
        by (cases "k = i"; cases "l = i", auto)
    qed
  qed ((transfer, auto)+)
qed auto


(* Row scaling on the representation (nr,nc,A): the row at position k is replaced by the
   row obtained from it by applying mul a to every entry; the other rows are kept.
   The proof obligation is that the new row array again only contains rows of the
   required length.
   NOTE(review): in contrast to mat_swaprows_impl there is no check k < nr in this
   definition; the row IArray.sub A k is accessed unconditionally -- confirm that
   callers only pass valid row indices. *)
lift_definition mat_multrow_gen_impl :: "('a  'a  'a)  nat  'a  'a mat_impl  'a mat_impl" is
  "λ mul k a (nr,nc,A). let Ak = IArray.sub A k; Arows = IArray.list_of A;
     Ak' = IArray.IArray (map (mul a) (IArray.list_of Ak));
     A' = IArray.IArray (Arows [k := Ak'])
     in (nr,nc,A')" 
proof (auto, goal_cases)
  case (1 mul k a nc b row)
  show ?case 
  proof (cases b)
    case (IArray rows)
    with 1 have "row  set rows  k < length rows  row = IArray (map (mul a) (IArray.list_of (rows ! k)))"
      by (cases "k < length rows", auto simp: set_list_update dest: in_set_takeD in_set_dropD)
    with 1 IArray show ?thesis by (cases, auto)
  qed
qed

(* Code equation for mat_multrow_gen on mat_impl A: the result is computed by
   mat_multrow_gen_impl. Entrywise equality is shown on the representation via transfer,
   with a case distinction on whether the scaled row k is the inspected row i. *)
lemma [code]: "mat_multrow_gen mul k a (mat_impl A) = mat_impl (mat_multrow_gen_impl mul k a A)"
  unfolding mat_multrow_gen_def
proof (rule eq_matI, goal_cases)
  case (1 i j)
  thus ?case 
  proof (transfer, goal_cases)
    case (1 i mul k a A j)
    obtain nr nc rows where A: "A = (nr,nc,rows)" by (cases A, auto)
    from 1[unfolded A]
    have nr: "length (IArray.list_of rows) = nr"
      and nc: "IArray.all (λr. length (IArray.list_of r) = nc) rows"
      and ij: "i < nr" "j < nc" and ij': "(i < nr  j < nc) = True" by auto
    have len: "j < length (IArray.list_of (IArray.list_of rows ! i))"
      using ij nc nr by (cases rows, auto)
    show ?case unfolding A prod.simps fst_conv o_def snd_conv Let_def mk_mat_def ij' if_True
      using ij nr nc 
      by (cases "k = i", auto simp: len)
  qed
qed ((transfer, auto)+)

(* Row addition on the representation (nr,nc,A): if l < nr, the row at position k is
   replaced by the row with entries ad (mul a (Al !! i)) (Ak !! i), where Ak and Al are
   the rows at positions k and l; the new row has the minimum of both row lengths, and
   all other rows are kept. If l is not below nr, the representation is unchanged.
   NOTE(review): only l is range-checked here, k is not -- confirm that callers only
   pass valid row indices k. *)
lift_definition mat_addrow_gen_impl 
  :: "('a  'a  'a)  ('a  'a  'a)  'a  nat  nat  'a mat_impl  'a mat_impl" is
  "λ ad mul a k l (nr,nc,A). if l < nr then let Ak = IArray.sub A k; Al = IArray.sub A l;
     Ak' = IArray.of_fun (λ i. ad (mul a (Al !! i)) (Ak !! i)) (min (IArray.length Ak) (IArray.length Al));
     A' = IArray.of_fun (λ i. if i = k then Ak' else A !! i) (IArray.length A)
     in (nr,nc,A') else (nr,nc,A)" 
proof (goal_cases)
  case (1 ad mul a k l pp)
  obtain nr nc A where pp: "pp = (nr,nc,A)" by (cases pp)
  obtain rows where A: "A = IArray rows" by (cases A)
  from 1[unfolded pp A, simplified]
  have nr: "length rows = nr" and nc: " r. rset rows  length (IArray.list_of r) = nc" by auto  
  show ?case 
  proof (cases "l < nr")
    case False
    thus ?thesis unfolding pp A prod.simps using nr nc by auto
  next
    case True    
    thus ?thesis unfolding pp A prod.simps Let_def using nr nc
      by (auto simp: set_list_update dest: in_set_takeD in_set_dropD)
  qed
qed

(* Code equation for mat_addrow_gen on mat_impl A: if l is below dim_row_impl A, the
   result is computed by the lifted function mat_addrow_gen_impl; otherwise Code.abort
   is invoked with an "index out of bounds" message and the original term as
   continuation. *)
lemma mat_addrow_gen_impl[code]: "mat_addrow_gen ad mul a k l (mat_impl A) = (if l < dim_row_impl A then
  mat_impl (mat_addrow_gen_impl ad mul a k l A) else Code.abort (STR ''index out of bounds in mat_addrow'') 
  (λ _. mat_addrow_gen ad mul a k l (mat_impl A)))" (is "?l = ?r")
proof (cases "l < dim_row_impl A")
  case True
  hence id: "?r = mat_impl (mat_addrow_gen_impl ad mul a k l A)" by simp
  show ?thesis unfolding id unfolding mat_addrow_gen_def
  proof (rule eq_matI, goal_cases)
    (* entrywise equality is shown on the level of the representation via transfer *)
    case (1 i j)
    thus ?case using True
    proof (transfer, goal_cases)
      case (1 i ad mul a k l A j)
      obtain nr nc rows where A: "A = (nr,nc,rows)" by (cases A, auto)
      from 1[unfolded A Let_def]
      have nr: "length (IArray.list_of rows) = nr"
        and nc: "IArray.all (λr. length (IArray.list_of r) = nc) rows"
        and ij: "i < nr" "j < nc" and ij': "(i < nr  j < nc) = True" 
        and l: "l < nr" by auto
      have len: "j < length (IArray.list_of (IArray.list_of rows ! i))"
        "j < length (IArray.list_of (IArray.list_of rows ! l))"
        using ij nc nr l by (cases rows, auto)+
      show ?case unfolding A prod.simps fst_conv o_def snd_conv Let_def mk_mat_def ij' if_True
        using ij nr nc l
        by (cases "k = i", auto simp: len)
    qed next
  qed ((transfer, auto simp:Let_def)+)
qed simp
 
(* Code equation for gauss_jordan_main. In the branch where the entry aij = A $$ (i,j) is
   neither 0 nor 1, row i of A and B is multiplied by inverse aij and the elimination
   step is performed immediately on the scaled matrices A' and B'.
   The proof shows that this agrees with gauss_jordan_main.simps: in that branch the
   left-hand side equals gauss_jordan_main on the scaled matrices, whose entry at (i,j)
   is 1, and unfolding gauss_jordan_main.simps once more gives the right-hand side. *)
lemma gauss_jordan_main_code[code]:
  "gauss_jordan_main A B i j = (let nr = dim_row A; nc = dim_col A in
    if i < nr  j < nc then let aij = A $$ (i,j) in if aij = 0 then
      (case [ i' . i' <- [Suc i ..< nr],  A $$ (i',j)  0] 
        of []  gauss_jordan_main A B i (Suc j)
         | (i' # _)  gauss_jordan_main (swaprows i i' A) (swaprows i i' B) i j)
      else if aij = 1 then let v = (λ i. A $$ (i,j)) in
        gauss_jordan_main 
        (eliminate_entries v A i j) (eliminate_entries v B i j) (Suc i) (Suc j)
      else let iaij = inverse aij; A' = multrow i iaij A; B' = multrow i iaij B;
        v = (λ i. A' $$ (i,j)) in gauss_jordan_main 
        (eliminate_entries v A' i j) (eliminate_entries v B' i j) (Suc i) (Suc j)
    else (A,B))" (is "?l = ?r")
proof -
  note simps = gauss_jordan_main.simps[of A B i j] Let_def
  let ?nr = "dim_row A" 
  let ?nc = "dim_col A"
  let ?A' = "multrow i (inverse (A $$ (i,j))) A" 
  let ?B' = "multrow i (inverse (A $$ (i,j))) B" 
  show ?thesis
  (* only the branch with an entry different from 0 and 1 needs a separate argument *)
  proof (cases "i < ?nr  j < ?nc  A $$ (i,j)  0  A $$ (i,j)  1")
    case False
    thus ?thesis unfolding simps by (auto split: if_splits)
  next
    case True
    from True have id: "?A' $$ (i,j) = 1" by auto
    from True have "?l = gauss_jordan_main ?A' ?B' i j" unfolding simps by (simp add: Let_def)
    also have " = ?r" unfolding Let_def gauss_jordan_main.simps[of ?A' ?B' i j] id 
      using True by simp
    finally show ?thesis .
  qed
qed 


end

Theory Column_Operations

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Elementary Column Operations›

text ‹We define elementary column operations and also combine them with elementary
  row operations. These combined operations are the basis to perform operations which
  preserve similarity of matrices. They are applied later on to convert upper triangular
  matrices into Jordan normal form.›

theory Column_Operations
imports
  Gauss_Jordan_Elimination
begin

(* multcol k a A: matrix of the same dimensions as A in which every entry of column k is
   multiplied by a (a * A $$ (i,k)); all other entries are those of A. *)
definition mat_multcol :: "nat  'a :: semiring_1  'a mat  'a mat" ("multcol") where
  "multcol k a A = mat (dim_row A) (dim_col A) 
     (λ (i,j). if k = j then a * A $$ (i,j) else A $$ (i,j))"

(* swapcols k l A: matrix of the same dimensions as A in which column k carries the
   entries of column l of A and column l carries the entries of column k of A. *)
definition mat_swapcols :: "nat  nat  'a mat  'a mat" ("swapcols")where
  "swapcols k l A = mat (dim_row A) (dim_col A) 
    (λ (i,j). if k = j then A $$ (i,l) else if l = j then A $$ (i,k) else A $$ (i,j))"

(* mat_addcol_vec k v A: matrix of the same dimensions as A in which v $ i is added to
   the entry in row i of column k; all other entries are those of A. *)
definition mat_addcol_vec :: "nat  'a :: plus vec  'a mat  'a mat" where
  "mat_addcol_vec k v A = mat (dim_row A) (dim_col A) 
    (λ (i,j). if k = j then v $ i + A $$ (i,j) else A $$ (i,j))"

(* addcol a k l A: matrix of the same dimensions as A in which a times column l is added
   to column k (entry a * A $$ (i,l) + A $$ (i,k)); all other entries are those of A. *)
definition mat_addcol :: "'a :: semiring_1  nat  nat  'a mat  'a mat" ("addcol") where
  "addcol a k l A = mat (dim_row A) (dim_col A) 
    (λ (i,j). if k = j then a * A $$ (i,l) + A $$ (i,j) else A $$ (i,j))"

(* Simp rules for multcol: the entry at a valid position (i,j) in general form and for
   the two special cases (position inside column k / outside column k), plus the fact
   that multcol preserves dim_row and dim_col. *)
lemma index_mat_multcol[simp]: 
  "i < dim_row A  j < dim_col A  multcol k a A $$ (i,j) = (if k = j then a * A $$ (i,j) else A $$ (i,j))"
  "i < dim_row A  j < dim_col A  multcol j a A $$ (i,j) = a * A $$ (i,j)"
  "i < dim_row A  j < dim_col A  k  j  multcol k a A $$ (i,j) = A $$ (i,j)"
  "dim_row (multcol k a A) = dim_row A" "dim_col (multcol k a A) = dim_col A"
  unfolding mat_multcol_def by auto

(* Simp rules for swapcols: the entry at a valid position (i,j), and the fact that
   swapcols preserves dim_row and dim_col. *)
lemma index_mat_swapcols[simp]: 
  "i < dim_row A  j < dim_col A  swapcols k l A $$ (i,j) = (if k = j then A $$ (i,l) else 
    if l = j then A $$ (i,k) else A $$ (i,j))"
  "dim_row (swapcols k l A) = dim_row A" "dim_col (swapcols k l A) = dim_col A"
  unfolding mat_swapcols_def by auto

(* Simp rules for addcol: the entry at a valid position (i,j) in general form and for the
   two special cases (position inside column k / outside column k), plus the fact that
   addcol preserves dim_row and dim_col. *)
lemma index_mat_addcol[simp]: 
  "i < dim_row A  j < dim_col A  addcol a k l A $$ (i,j) = (if k = j then 
    a * A $$ (i,l) + A $$ (i,j) else A $$ (i,j))"
  "i < dim_row A  j < dim_col A  addcol a j l A $$ (i,j) = a * A $$(i,l) + A$$(i,j)"
  "i < dim_row A  j < dim_col A  k  j  addcol a k l A $$ (i,j) = A $$(i,j)"
  "dim_row (addcol a k l A) = dim_row A" "dim_col (addcol a k l A) = dim_col A"
  unfolding mat_addcol_def by auto

text ‹Each column-operation can be seen as a multiplication by 
  an elementary matrix from the right›

(* Columns of the elementary matrix addrow_mat n a k l: every column i < n different
   from l is the unit vector unit_vec n i, and column l is the unit vector for l plus
   a times the unit vector for k. *)
lemma col_addrow: 
  "l  i  i < n  col (addrow_mat n a k l) i = unit_vec n i"
  "k < n  l < n  col (addrow_mat n a k l) l = a v unit_vec n k + unit_vec n l" 
  by (rule eq_vecI, auto)

(* Column k of addcol a k l A is column k of A plus a times column l of A
   (for valid column indices k and l). *)
lemma col_addcol[simp]:
  "k < dim_col A  l < dim_col A  col (addcol a k l A) k = a v col A l + col A k"
  by (rule eq_vecI;simp)

(* addcol as multiplication from the right: for A with n columns and k < n,
   addcol a l k A equals A * addrow_mat n a k l. Note the exchanged order of the two
   indices between addcol and addrow_mat. *)
lemma addcol_mat: assumes A: "A  carrier_mat nr n" 
  and k: "k < n"
  shows "addcol (a :: 'a :: comm_semiring_1) l k A = A * addrow_mat n a k l"
  by (rule eq_matI, insert A k, auto simp: col_addrow
  scalar_prod_add_distrib[of _ n] scalar_prod_smult_distrib[of _ n])

(* Columns of the elementary matrix multrow_mat n k a: every column i < n different from
   k is the unit vector unit_vec n i, and column k is a times the unit vector for k. *)
lemma col_multrow:  "k  i  i < n  col (multrow_mat n k a) i = unit_vec n i"
  "k < n  col (multrow_mat n k a) k = a v unit_vec n k"
  by (rule eq_vecI, auto)

(* multcol as multiplication from the right: for A with n columns,
   multcol k a A equals A * multrow_mat n k a. *)
lemma multcol_mat: assumes A: "(A :: 'a :: comm_ring_1 mat)  carrier_mat nr n"
  shows "multcol k a A = A * multrow_mat n k a"
  by (rule eq_matI, insert A, auto simp: col_multrow smult_scalar_prod_distrib[of _ n])

(* Columns of the elementary matrix swaprows_mat n k l: columns k and l are the unit
   vectors for l and k, respectively; every other column i < n is unit_vec n i.
   The first statement covers the degenerate case where both indices coincide. *)
lemma col_swaprows: 
  "l < n  col (swaprows_mat n l l) l = unit_vec n l"
  "i  k  i  l  i < n  col (swaprows_mat n k l) i = unit_vec n i"
  "k < n  l < n  col (swaprows_mat n k l) l = unit_vec n k"
  "k < n  l < n  col (swaprows_mat n k l) k = unit_vec n l"
  by (rule eq_vecI, auto)

(* swapcols as multiplication from the right: for A with n columns and k, l < n,
   swapcols k l A equals A * swaprows_mat n k l. *)
lemma swapcols_mat: assumes A: "A  carrier_mat nr n" and k: "k < n" "l < n"
  shows "swapcols k l A = A * swaprows_mat n k l"
  by (rule eq_matI, insert A k, auto simp: col_swaprows)

text ‹Combining row and column-operations yields similarity transformations.›

(* add_col_sub_row a k l A: first apply the column operation addcol a l k, then the row
   operation addrow (- a) k l to the result. *)
definition add_col_sub_row :: "'a :: ring_1  nat  nat  'a mat  'a mat"  where
  "add_col_sub_row a k l A = addrow (- a) k l (addcol a l k A)"

(* mult_col_div_row a k A: first multiply column k by a (multcol), then multiply row k
   of the result by inverse a (multrow). *)
definition mult_col_div_row :: "'a :: field  nat  'a mat  'a mat" where
  "mult_col_div_row a k A = multrow k (inverse a) (multcol k a A)"

(* swap_cols_rows k l A: first swap columns k and l (swapcols), then swap rows k and l
   of the result (swaprows). *)
definition swap_cols_rows :: "nat  nat  'a mat  'a mat" where
  "swap_cols_rows k l A = swaprows k l (swapcols k l A)"


(* add_col_sub_row preserves dim_row and dim_col, and maps matrices of carrier_mat n n
   into carrier_mat n n. *)
lemma add_col_sub_row_carrier[simp]: 
  "dim_row (add_col_sub_row a k l A) = dim_row A"
  "dim_col (add_col_sub_row a k l A) = dim_col A"
  "A  carrier_mat n n  add_col_sub_row a k l A  carrier_mat n n"
  unfolding add_col_sub_row_def carrier_mat_def by auto

(* Explicit entry of add_col_sub_row a k l A at position (i,j), where i, j are valid as
   row and as column index and l is a valid row index. Four cases are distinguished,
   depending on whether i = k (the modified row) and whether j = l (the modified
   column). *)
lemma add_col_sub_index_row[simp]: 
  "i < dim_row A  i < dim_col A  j < dim_row A  j < dim_col A  l < dim_row A 
     add_col_sub_row a k l A $$ (i,j) = (if 
      i = k  j = l then A $$ (i, j) + a * A $$ (i, i) - a * a * A $$ (j, i) - a * A $$ (j, j) else if
      i = k  j  l then A $$ (i, j) - a * A $$ (l, j) else if
      i  k  j = l then A $$ (i, j) + a * A $$ (i, k) else A $$ (i,j))"
  unfolding add_col_sub_row_def by (auto simp: field_simps)

(* Explicit entry of mult_col_div_row a k A at position (i,j) for a different from 0,
   where i, j are valid as row and as column index: off-diagonal entries in row k are
   multiplied by inverse a, off-diagonal entries in column k are multiplied by a, and
   all remaining entries are those of A. *)
lemma mult_col_div_index_row[simp]: 
  "i < dim_row A  i < dim_col A  j < dim_row A  j < dim_col A  a  0
     mult_col_div_row a k A $$ (i,j) = (if 
      i = k  j  i then inverse a * A $$ (i, j) else if
      j = k  j  i then a * A $$ (i, j) else A $$ (i,j))"
  unfolding mult_col_div_row_def by auto

(* mult_col_div_row preserves dim_row and dim_col, and maps matrices of carrier_mat n n
   into carrier_mat n n. *)
lemma mult_col_div_row_carrier[simp]: 
  "dim_row (mult_col_div_row a k A) = dim_row A"
  "dim_col (mult_col_div_row a k A) = dim_col A"
  "A  carrier_mat n n  mult_col_div_row a k A  carrier_mat n n"
  unfolding mult_col_div_row_def carrier_mat_def by auto

(* swap_cols_rows preserves dim_row and dim_col, and maps matrices of carrier_mat n n
   into carrier_mat n n. *)
lemma swap_cols_rows_carrier[simp]: 
  "dim_row (swap_cols_rows k l A) = dim_row A"
  "dim_col (swap_cols_rows k l A) = dim_col A"
  "A  carrier_mat n n  swap_cols_rows k l A  carrier_mat n n"
  unfolding swap_cols_rows_def carrier_mat_def by auto

(* Explicit entry of swap_cols_rows a b A at position (i,j): it is the entry of A at the
   position obtained by exchanging a and b in both the row index and the column index.
   Requires i, j to be valid as row and as column index and a, b to be valid row
   indices. *)
lemma swap_cols_rows_index[simp]: 
  "i < dim_row A  i < dim_col A  j < dim_row A  j < dim_col A  a < dim_row A  b < dim_row A 
     swap_cols_rows a b A $$ (i,j) = A $$ (if i = a then b else if i = b then a else i,
     if j = a then b else if j = b then a else j)"
  unfolding swap_cols_rows_def 
  by auto 

(* add_col_sub_row a k l A is similar to A, for a square matrix A and distinct indices
   k, l < n. The witnesses are P = addrow_mat n (-a) k l and Q = addrow_mat n a k l,
   which are inverse to each other, and the result equals P * A * Q. *)
lemma add_col_sub_row_similar: assumes A: "A  carrier_mat n n" and kl: "k < n" "l < n" "k  l"
  shows "similar_mat (add_col_sub_row a k l A) (A :: 'a :: comm_ring_1 mat)"
proof (rule similar_matI)
  let ?P = "addrow_mat n (-a) k l"
  let ?Q = "addrow_mat n a k l"
  let ?B = "add_col_sub_row a k l A"
  show carr: "{?B, A, ?P, ?Q}  carrier_mat n n" using A by auto
  show "?Q * ?P = 1m n" by (rule addrow_mat_inv[OF kl])
  show "?P * ?Q = 1m n" using addrow_mat_inv[OF kl, of "-a"] by simp
  (* the column operation is multiplication by ?Q from the right, the row operation is
     multiplication by ?P from the left *)
  have col: "addcol a l k A = A * ?Q"
    by (rule addcol_mat[OF A kl(1)])
  have "?B = ?P * (A * ?Q)" unfolding add_col_sub_row_def col
    by (rule addrow_mat[OF _ kl(2), of _ n], insert A, simp)
  thus "?B = ?P * A * ?Q" using carr by (simp add: assoc_mult_mat[of _ n n _ n _ n])
qed

(* mult_col_div_row a k A is similar to A, for a square matrix A, k < n and a different
   from 0. The witnesses are P = multrow_mat n k (inverse a) and Q = multrow_mat n k a,
   which are inverse to each other, and the result equals P * A * Q. *)
lemma mult_col_div_row_similar: assumes A: "A  carrier_mat n n" and ak: "k < n" "a  0"
  shows "similar_mat (mult_col_div_row a k A) A"
proof (rule similar_matI)
  let ?P = "multrow_mat n k (inverse a)"
  let ?Q = "multrow_mat n k a"
  let ?B = "mult_col_div_row a k A"
  show carr: "{?B, A, ?P, ?Q}  carrier_mat n n" using A by auto
  show "?Q * ?P = 1m n" by (rule multrow_mat_inv[OF ak])
  show "?P * ?Q = 1m n" using multrow_mat_inv[OF ak(1), of "inverse a"] ak(2) by simp
  (* the column operation is multiplication by ?Q from the right, the row operation is
     multiplication by ?P from the left *)
  have col: "multcol k a A = A * ?Q"
    by (rule multcol_mat[OF A])
  have "?B = ?P * (A * ?Q)" unfolding mult_col_div_row_def col
    by (rule multrow_mat[of _ n n], insert A, simp)
  thus "?B = ?P * A * ?Q" using carr by (simp add: assoc_mult_mat[of _ n n _ n _ n])
qed

(* swap_cols_rows k l A is similar to A, for a square matrix A and k, l < n. The single
   witness P = swaprows_mat n k l is its own inverse, and the result equals P * A * P. *)
lemma swap_cols_rows_similar: assumes A: "A  carrier_mat n n" and kl: "k < n" "l < n"
  shows "similar_mat (swap_cols_rows k l A) A"
proof (rule similar_matI)
  let ?P = "swaprows_mat n k l"
  let ?B = "swap_cols_rows k l A"
  show carr: "{?B, A, ?P, ?P}  carrier_mat n n" using A by auto
  (* both inverse conditions of similar_matI are the same statement here *)
  show "?P * ?P = 1m n" by (rule swaprows_mat_inv[OF kl])
  show "?P * ?P = 1m n" by fact
  have col: "swapcols k l A = A * ?P"
    by (rule swapcols_mat[OF A kl])
  have "?B = ?P * (A * ?P)" unfolding swap_cols_rows_def col
    by (rule swaprows_mat[of _ n n ], insert A kl, auto)
  thus "?B = ?P * A * ?P" using carr by (simp add: assoc_mult_mat[of _ n n _ n _ n])
qed

(* THIS LINE SEPARATES AFP-ENTRY FROM NEWER DEVELOPMENTS *)

(* swapcols does not change the dimensions: the result is in carrier_mat n m iff the
   argument is. *)
lemma swapcols_carrier[simp]: "(swapcols l k A  carrier_mat n m) = (A  carrier_mat n m)"
  unfolding mat_swapcols_def carrier_mat_def by auto

(* swap_row_to_front A I: repeatedly swaps neighbouring rows (I with Suc I, counting
   downwards) until index 0 is reached; the effect is described by
   swap_row_to_front_result. *)
fun swap_row_to_front :: "'a mat  nat  'a mat" where
  "swap_row_to_front A 0 = A"
| "swap_row_to_front A (Suc I) = swap_row_to_front (swaprows I (Suc I) A) I"

(* swap_col_to_front A I: repeatedly swaps neighbouring columns (I with Suc I, counting
   downwards) until index 0 is reached; the effect is described by
   swap_col_to_front_result. *)
fun swap_col_to_front :: "'a mat  nat  'a mat" where
  "swap_col_to_front A 0 = A"
| "swap_col_to_front A (Suc I) = swap_col_to_front (swapcols I (Suc I) A) I"

(* Closed form of swap_row_to_front A I for A in carrier_mat n m and I < n: row 0 of the
   result is row I of A, every further row i up to I is row i - 1 of A, and the rows
   after I are unchanged. Proved by induction on I with arbitrary A. *)
lemma swap_row_to_front_result: "A  carrier_mat n m  I < n  swap_row_to_front A I = 
  mat n m (λ (i,j). if i = 0 then A $$ (I,j)
  else if i  I then A $$ (i - 1, j) else A $$ (i,j))"
proof (induct I arbitrary: A)
  case 0
  thus ?case
    by (intro eq_matI, auto)
next
  case (Suc I A)
  from Suc(3) have I: "I < n" by auto
  let ?I = "Suc I"
  let ?A = "swaprows I ?I A"
  have AA: "?A  carrier_mat n m" using Suc(2) by simp
  (* apply the induction hypothesis to the matrix after one swap, then express the
     entries of ?A via those of A *)
  have "swap_row_to_front A (Suc I) = swap_row_to_front ?A I" by simp
  also have " = mat n m
   (λ(i, j). if i = 0 then ?A $$ (I, j)
       else if i  I then ?A $$ (i - 1, j) else ?A $$ (i, j))" 
     using Suc(1)[OF AA I] by simp
  also have " = mat n m
   (λ(i, j). if i = 0 then A $$ (?I, j)
       else if i  ?I then A $$ (i - 1, j) else A $$ (i, j))" 
    by (rule eq_matI, insert I Suc(2), auto)
  finally show ?case .
qed


(* Closed form of swap_col_to_front A J for J < m: column 0 is the old column J, the
   columns up to index J are the old columns shifted right by one, and all later
   columns are unchanged. Proved by induction on J, mirroring swap_row_to_front_result. *)
lemma swap_col_to_front_result: "A  carrier_mat n m  J < m  swap_col_to_front A J = 
  mat n m (λ (i,j). if j = 0 then A $$ (i,J)
  else if j  J then A $$ (i, j-1) else A $$ (i,j))"
proof (induct J arbitrary: A)
  case 0
  thus ?case
    by (intro eq_matI, auto)
next
  case (Suc J A)
  from Suc(3) have J: "J < m" by auto
  let ?J = "Suc J"
  (* ?A is A after the first adjacent swap performed by the recursive equation. *)
  let ?A = "swapcols J ?J A"
  have AA: "?A  carrier_mat n m" using Suc(2) by simp
  have "swap_col_to_front A (Suc J) = swap_col_to_front ?A J" by simp
  (* Induction hypothesis applied to ?A. *)
  also have " = mat n m
   (λ(i, j). if j = 0 then ?A $$ (i, J)
          else if j  J then ?A $$ (i, j - 1) else ?A $$ (i, j))" 
     using Suc(1)[OF AA J] by simp
  (* Rewrite the entries of ?A in terms of the entries of A. *)
  also have " = mat n m
   (λ(i, j). if j = 0 then A $$ (i, ?J)
          else if j  ?J then A $$ (i, j - 1) else A $$ (i, j))" 
    by (rule eq_matI, insert J Suc(2), auto)
  finally show ?case .
qed

(* A column swap can be expressed as: transpose, swap the corresponding rows, transpose back. *)
lemma swapcols_is_transp_swap_rows: assumes A: "A  carrier_mat n m" "k < m" "l < m"
  shows "swapcols k l A = transpose_mat (swaprows k l (transpose_mat A))"
  using assms by (intro eq_matI, auto)



end

Theory Determinant

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Determinants›

text ‹Most of the following definitions and proofs on determinants have been copied and adapted 
  from ~~/src/HOL/Multivariate-Analysis/Determinants.thy.

An exception is \emph{det-identical-rows}.

We further generalized some lemmas; e.g., the statement that the determinant is 0 iff the kernel
of a matrix is non-trivial is available for integral domains, not just for fields.›

theory Determinant
imports 
  Missing_Permutations
  Column_Operations
  "HOL-Computational_Algebra.Polynomial_Factorial" (* Only for to_fract. Probably not the right place. *)
  Polynomial_Interpolation.Ring_Hom
  Polynomial_Interpolation.Missing_Unsorted
begin

(* Determinant via the permutation expansion: for a square matrix, the terms
   signof p * (product of the entries A $$ (i, p i)) are combined over all
   permutations p of {0 ..< dim_row A}. Non-square matrices get determinant 0. *)
definition det:: "'a mat  'a :: comm_ring_1" where
  "det A = (if dim_row A = dim_col A then ( p  {p. p permutes {0 ..< dim_row A}}. 
     signof p * ( i = 0 ..< dim_row A. A $$ (i, p i))) else 0)"

(* A ring homomorphism maps the sign of a permutation to the sign of that permutation. *)
lemma(in ring_hom) hom_signof[simp]: "hom (signof p) = signof p"
  unfolding signof_def by (auto simp: hom_distribs)

(* The determinant commutes with applying a commutative-ring homomorphism entrywise. *)
lemma(in comm_ring_hom) hom_det[simp]: "det (map_mat hom A) = hom (det A)"
  unfolding det_def by (auto simp: hom_distribs)

(* Variant of det_def for A in carrier_mat n n: the dimension test disappears and
   the expansion is stated with n instead of dim_row A. *)
lemma det_def': "A  carrier_mat n n  
  det A = ( p  {p. p permutes {0 ..< n}}. 
     signof p * ( i = 0 ..< n. A $$ (i, p i)))" unfolding det_def by auto

(* Scaling a matrix by a multiplies its determinant by a ^ dim_col A. *)
lemma det_smult[simp]: "det (a m A) = a ^ dim_col A * det A"
proof -
  (* A constant product over the column indices is a power of that constant. *)
  have [simp]: "(i = 0..<dim_col A. a) = a ^ dim_col A" by(subst prod_constant;simp)
  show ?thesis
  unfolding det_def
  unfolding index_smult_mat
  by (auto intro: sum.cong simp: sum_distrib_left prod.distrib)
qed

(* Transposing a square matrix does not change its determinant.
   For every permutation p, the summand of the transposed matrix at inv p equals
   the summand of A at p; the sum is then reindexed over inverse permutations. *)
lemma det_transpose: assumes A: "A  carrier_mat n n"
  shows "det (transpose_mat A) = det A"
proof -
  let ?U = "{0 ..< n}"
  let ?inv = "Hilbert_Choice.inv"
  {
    fix p
    assume p: "p  {p. p permutes ?U}"
    from p have pU: "p permutes ?U"
      by blast
    (* A permutation and its inverse have the same sign. *)
    have sth: "signof (?inv p) = signof p"
      by (rule signof_inv[OF _ pU], simp)
    from permutes_inj[OF pU]
    have pi: "inj_on p ?U"
      by (blast intro: subset_inj_on)
    let ?f = "λi. transpose_mat A $$ (i, ?inv p i)"
    note pU_U = permutes_image[OF pU]
    note [simp] = permutes_less[OF pU]
    (* Reindex the product along p, which maps ?U onto itself. *)
    have "prod ?f ?U = prod ?f (p ` ?U)"
      using pU_U by simp
    also have " = prod (?f  p) ?U"
      by (rule prod.reindex[OF pi])
    also have " = prod (λi. A $$ (i, p i)) ?U"
      by (rule prod.cong, insert A, auto)
    finally have "signof (?inv p) * prod ?f ?U =
      signof p * prod (λi. A $$ (i, p i)) ?U"
      unfolding sth by simp
  }
  then show ?thesis
    unfolding det_def using A
    by (simp, subst sum_permutations_inverse, intro sum.cong, auto)
qed

(* Column-oriented expansion of the determinant: the product in each summand runs
   over the entries A $$ (p j, j) instead of A $$ (j, p j). *)
lemma det_col:
  assumes A: "A  carrier_mat n n"
  shows "det A = ( p | p permutes {0 ..< n}. signof p * (j<n. A $$ (p j, j)))"
    (is "_ = (sum (λp. _ * ?prod p) ?P)")
proof -
  let ?i = "Hilbert_Choice.inv"
  let ?N = "{0 ..< n}"
  let ?f = "λp. signof p * ?prod p"
  let ?prod' = "λp. j<n. A $$ (j, ?i p j)"
  let ?prod'' = "λp. j<n. A $$ (j, p j)"
  let ?f' = "λp. signof (?i p) * ?prod' p"
  let ?f'' = "λp. signof p * ?prod'' p"
  let ?P' = "{ ?i p | p. p permutes ?N }"
  have [simp]: "{0..<n} = {..<n}" by auto
  (* Step 1: express each summand through the inverse permutation. *)
  have "sum ?f ?P = sum ?f' ?P"
  proof (rule sum.cong[OF refl],unfold mem_Collect_eq)
      fix p assume p: "p permutes ?N"
      have [simp]: "?prod p = ?prod' p"
        using permutes_prod[OF p, of "λx y. A $$ (x,y)"] by auto
      have [simp]: "signof p = signof (?i p)"
        apply(rule signof_inv[symmetric]) using p by auto
      show "?f p = ?f' p" by auto
  qed
  (* Step 2: the set of permutations equals the set of their inverses. *)
  also have "... = sum ?f' ?P'"
    by (rule sum.cong[OF image_inverse_permutations[symmetric]],auto)
  (* Step 3: reindex along inv, using that inv (inv p) = p for bijections. *)
  also have "... = sum ?f'' ?P"
    unfolding sum.reindex[OF inv_inj_on_permutes,unfolded image_Collect]
    unfolding o_def
    apply (rule sum.cong[OF refl])
    using inv_inv_eq[OF permutes_bij] by force
  finally show ?thesis unfolding det_def'[OF A] by auto
qed

(* Column-oriented expansion of the determinant, stated with dim_row A instead of n.
   Obtained by unfolding the definition of det for the transposed matrix. *)
lemma mat_det_left_def: assumes A: "A  carrier_mat n n"
  shows "det A = (p{p. p permutes {0..<dim_row A}}. signof p * (i = 0 ..< dim_row A. A $$ (p i, i)))"
proof -
  show ?thesis
  unfolding det_transpose[OF A, symmetric]
  unfolding det_def index_transpose_mat using A by simp
qed

(* The determinant of an upper triangular square matrix is the product of its
   diagonal entries. Every non-identity permutation contributes 0, because it hits
   an entry for which upper_triangularD yields 0; only the identity summand remains. *)
lemma det_upper_triangular:
  assumes ut: "upper_triangular A"
  and m: "A  carrier_mat n n"
  shows "det A = prod_list (diag_mat A)"
proof -
  note det_def = det_def'[OF m]
  let ?U = "{0..<n}"
  let ?PU = "{p. p permutes ?U}"
  let ?pp = "λp. signof p * ( i = 0 ..< n. A $$ (i, p i))"
  have fU: "finite ?U"
    by simp
  from finite_permutations[OF fU] have fPU: "finite ?PU" .
  have id0: "{id}  ?PU"
    by (auto simp add: permutes_id)
  {
    fix p
    assume p: "p  ?PU - {id}"
    from p have pU: "p permutes ?U" and pid: "p  id"
      by blast+
    (* A non-identity permutation of {0..<n} maps some index i to a smaller one. *)
    from permutes_natset_ge[OF pU] pid obtain i where i: "p i < i" and "i < n" 
      by fastforce
    from upper_triangularD[OF ut i] i < n m
    have ex:"i  ?U. A $$ (i,p i) = 0" by auto
    have "( i = 0 ..< n. A $$ (i, p i)) = 0" 
      by (rule prod_zero[OF fU ex])
    hence "?pp p = 0" by simp
  }
  then have p0: " p. p  ?PU - {id}  ?pp p = 0"
    by blast
  have "det A = ( p  ?PU. ?pp p)" unfolding det_def by auto
  (* Split off the identity summand; the remaining sum vanishes by p0. *)
  also have " = ?pp id + ( p  ?PU - {id}. ?pp p)"
    by (rule sum.remove, insert id0 fPU m, auto simp: p0)
  also have "( p  ?PU - {id}. ?pp p) = 0"
    by (rule sum.neutral, insert fPU, auto simp: p0)
  finally show ?thesis using m by (auto simp: prod_list_diag_prod)
qed

(* The determinant of the identity matrix is 1 (via det_upper_triangular). *)
lemma det_one[simp]: "det (1m n) = 1"
proof -
  have "det (1m n) = prod_list (diag_mat (1m n))"
    by (rule det_upper_triangular[of _ n], auto)
  also have " = 1" by (induct n, auto)
  finally show ?thesis .
qed

(* The determinant of the n x n zero matrix is 0, provided n > 0
   (for n = 0 see det_dim_zero, where the determinant is 1). *)
lemma det_zero[simp]: assumes "n > 0" shows "det (0m n n) = 0"
proof -
  have "det (0m n n) = prod_list (diag_mat (0m n n))"
    by (rule det_upper_triangular[of _ n], auto)
  also have " = 0" using n > 0 by (cases n, auto)
  finally show ?thesis .
qed

(* The determinant of a 0 x 0 matrix is 1. *)
lemma det_dim_zero[simp]: "A  carrier_mat 0 0  det A = 1"
  unfolding det_def carrier_mat_def signof_def sign_def by auto
 

(* If all entries above the diagonal (i < j) of a square matrix are 0, the determinant
   is the product of the diagonal entries. Reduced to det_upper_triangular by transposing. *)
lemma det_lower_triangular:
  assumes ld: "i j. i < j  j < n  A $$ (i,j) = 0"
  and m: "A  carrier_mat n n"
  shows "det A = prod_list (diag_mat A)"
proof -
  have "det A = det (transpose_mat A)" using det_transpose[OF m] by simp
  also have " = prod_list (diag_mat (transpose_mat A))"
    by (rule det_upper_triangular, insert m ld, auto)
  finally show ?thesis using m by simp
qed

(* Permuting the rows of a square matrix by a permutation p multiplies the
   determinant by signof p. Both sides are written as sums over permutations q
   (the left one reindexed by composing with p) and compared summand by summand. *)
lemma det_permute_rows: assumes A: "A  carrier_mat n n"
  and p: "p permutes {0 ..< (n :: nat)}"
  shows "det (mat n n (λ (i,j). A $$ (p i, j))) = signof p * det A"
proof -
  let ?U = "{0 ..< (n :: nat)}"
  have "det (mat n n (λ (i,j). A $$ (p i, j))) = 
    ( q  {q . q permutes ?U}. signof q * ( i  ?U. A $$ (p i, q i)))"
    unfolding det_def using A p by auto
  also have " = ( q  {q . q permutes ?U}. signof (q  p) * ( i  ?U. A $$ (p i, (q  p) i)))"
    by (rule sum_permutations_compose_right[OF p])
  finally have 1: "det (mat n n (λ (i,j). A $$ (p i, j)))
    = ( q  {q . q permutes ?U}. signof (q  p) * ( i  ?U. A $$ (p i, (q  p) i)))" .
  have 2: "signof p * det A = 
    ( q{q. q permutes ?U}. signof p * signof q * (i ?U. A $$ (i, q i)))"
    unfolding det_def'[OF A] sum_distrib_left by (simp add: ac_simps)
  show ?thesis unfolding 1 2
  proof (rule sum.cong, insert p A, auto)
    fix q
    assume q: "q permutes ?U"
    let ?inv = "Hilbert_Choice.inv"
    from permutes_inv[OF p] have ip: "?inv p permutes ?U" .
    (* Reindex the product along inv p; p composed with inv p then cancels. *)
    have "prod (λi. A $$ (p i, (q  p) i)) ?U = 
      prod (λi. A $$ ((p  ?inv p) i, (q  (p  ?inv p)) i)) ?U" unfolding o_def
      by (rule trans[OF prod.permute[OF ip] prod.cong], insert A p q, auto)
    also have " = prod (λi. A$$(i,q i)) ?U"
      by (simp only: o_def permutes_inverses[OF p])
    finally have thp: "prod (λi. A $$ (p i, (q  p) i)) ?U = prod (λi. A$$(i,q i)) ?U" .      
    show "signof (q  p) * (i{0..<n}. A $$ (p i, q (p i))) =
         signof p * signof q * (i{0..<n}. A $$ (i, q i))"
      unfolding thp[symmetric] signof_compose[OF q p]
      by (simp add: ac_simps)
  qed
qed

(* The determinant of the elementary matrix multrow_mat n k a is a, for k < n.
   Uses det_lower_triangular and splits the diagonal product at index k. *)
lemma det_multrow_mat: assumes k: "k < n"
  shows "det (multrow_mat n k a) = a"
proof (rule trans[OF det_lower_triangular[of n]], unfold prod_list_diag_prod)
  let ?f = "λ i. multrow_mat n k a $$ (i, i)"
  have "(i{0..<n}. ?f i) = ?f k * (i{0..<n} - {k}. ?f i)"
    by (rule prod.remove, insert k, auto)
  also have "(i{0..<n} - {k}. ?f i) = 1" 
    by (rule prod.neutral, auto)
  finally show "(i{0..<dim_row (multrow_mat n k a)}. ?f i) = a" using k by simp
qed (insert k, auto)

(* The row-swap matrix is the identity matrix with its rows permuted by the
   transposition Fun.swap k l id. *)
lemma swap_rows_mat_eq_permute: 
  "k < n  l < n  swaprows_mat n k l = mat n n (λ(i, j). 1m n $$ (Fun.swap k l id i, j))"
  by (rule eq_matI) (auto simp add: swap_id_eq)

(* The determinant of the row-swap matrix for two distinct indices k, l < n is -1.
   Follows from det_permute_rows applied to the identity matrix and the transposition. *)
lemma det_swaprows_mat: assumes k: "k < n" and l: "l < n" and kl: "k  l"
  shows "det (swaprows_mat n k l) = - 1"
proof -
  let ?n = "{0 ..< (n :: nat)}"
  let ?p = "Fun.swap k l id"
  have p: "?p permutes ?n"
    by (rule permutes_swap_id, insert k l, auto)
  show ?thesis
    by (rule trans[OF trans[OF _ det_permute_rows[OF one_carrier_mat[of n] p]]],
    subst swap_rows_mat_eq_permute[OF k l], auto simp: signof_def sign_swap_id kl)
qed
  
(* The determinant of the elementary matrix addrow_mat n a k l is 1 when k and l differ.
   Depending on whether k < l, the matrix is handled by det_upper_triangular or
   det_lower_triangular; the diagonal product is then shown to be 1. *)
lemma det_addrow_mat: 
  assumes l: "k  l"
  shows "det (addrow_mat n a k l) = 1"
proof -
  have "det (addrow_mat n a k l) = prod_list (diag_mat (addrow_mat n a k l))"
  proof (cases "k < l")
    case True
    show ?thesis
      by (rule det_upper_triangular[of _ n], insert True, auto intro!: upper_triangularI)
  next
    case False
    show ?thesis
      by (rule det_lower_triangular[of n], insert False, auto)
  qed
  also have " = 1" unfolding prod_list_diag_prod
    by (rule prod.neutral, insert l, auto)
  finally show ?thesis .
qed

text ‹The following proof is new, as it does not use $2 \neq 0$ as in Multivariate-Analysis.›

(* A square matrix with two equal rows at distinct indices i, j has determinant 0.
   Proof idea: split the permutations into those with sign 1 (?one) and the rest (?none).
   Composing with the transposition ?p of i and j maps ?one bijectively onto ?none,
   flips the sign, and leaves the product of entries unchanged because rows i and j
   coincide; hence the summands cancel pairwise. *)
lemma det_identical_rows:
  assumes A: "A  carrier_mat n n"  
    and ij: "i  j"
    and i: "i < n" and j: "j < n"
    and r: "row A i = row A j"
  shows "det A = 0"
proof-
  let ?p = "Fun.swap i j id"
  let ?n = "{0 ..< n}"
  have sp: "signof ?p = - 1" "sign ?p = -1" unfolding signof_def using ij
    by (auto simp add: sign_swap_id)
  let ?f = "λ p. signof p * (i?n. A $$ (p i, i))"
  let ?all = "{p. p permutes ?n}"
  let ?one = "{p. p permutes ?n  sign p = 1}"
  let ?none = "{p. p permutes ?n  sign p  1}"
  let ?pone = "(λ p. ?p o p) ` ?one"
  have split: "?one  ?none = ?all" by auto
  have p: "?p permutes ?n" by (rule permutes_swap_id, insert i j, auto)
  from permutes_inj[OF p] have injp: "inj ?p" by auto
  (* Post-composing with ?p does not change the product, since rows i and j are equal. *)
  {
    fix q
    assume q: "q permutes ?n"
    have "(k?n. A $$ (?p (q k), k)) = (k?n. A $$ (q k, k))"
    proof (rule prod.cong)
      fix k
      assume k: "k  ?n"
      from r have row: "row A i $ k = row A j $ k" by simp
      hence "A $$ (i,k) = A $$ (j,k)" using k i j A by auto
      thus "A $$ (?p (q k), k) = A $$ (q k, k)"
        by (cases "q k = i", auto, cases "q k = j", auto)
    qed (insert A q, auto)
  } note * = this
  have pp: " q. q permutes ?n  permutation q" unfolding 
    permutation_permutes by auto
  have "det A = (p ?one  ?none. ?f p)"
    using A unfolding mat_det_left_def[OF A] split by simp
  also have " = (p ?one. ?f p) + (p ?none. ?f p)"
    by (rule sum.union_disjoint, insert A, auto simp: finite_permutations)
  (* ?none is exactly the image of ?one under composition with ?p. *)
  also have "?none = ?pone" 
  proof -
    {
      fix q
      assume "q  ?none"
      hence q: "q permutes ?n" and sq: "sign q = -1" unfolding sign_def by auto
      from permutes_compose[OF q p] sign_compose[OF pp[OF p] pp[OF q], unfolded sp sq]
      have "?p o q  ?one" by auto
      hence "?p o (?p o q)  ?pone" by auto
      also have "?p o (?p o q) = q"
        by (auto simp: swap_id_eq)
      finally have "q  ?pone" .
    }
    moreover
    {
      fix pq
      assume "pq  ?pone"
      then obtain q where q: "q  ?one" and pq: "pq = ?p o q" by auto
      from q have q: "q permutes ?n" and sq: "sign q = 1" by auto
      from sign_compose[OF pp[OF p] pp[OF q], unfolded sq sp] have spq: "sign pq = -1" unfolding pq by auto
      from permutes_compose[OF q p] have pq: "pq permutes ?n" unfolding pq by auto
      from pq spq have "pq  ?none" by auto
    }
    ultimately
    show ?thesis by blast
  qed  
  (* Reindex the second sum over ?one, then combine both sums and cancel. *)
  also have "(p ?pone. ?f p) = (p ?one. ?f (?p o p))"
  proof (rule trans[OF sum.reindex])
    show "inj_on ((∘) ?p) ?one" 
      using fun.inj_map[OF injp] unfolding inj_on_def by auto
  qed simp
  also have "(p ?one. ?f p) + (p ?one. ?f (?p o p))
    = (p ?one. ?f p + ?f (?p o p))"
    by (rule sum.distrib[symmetric])
  also have " = 0"
    by (rule sum.neutral, insert A, auto simp: 
      sp sign_compose[OF pp[OF p] pp] ij signof_def finite_permutations *)
  finally show ?thesis .
qed

(* A matrix built from rows whose k-th row is the zero vector has determinant 0:
   every product in the expansion contains an entry of row k. *)
lemma det_row_0: assumes k: "k < n"
  and c: "c  {0 ..< n}  carrier_vec n"
  shows "det (matr n n (λi. if i = k then 0v n else c i)) = 0"
proof -
  {
    fix p
    assume p: "p permutes {0 ..< n}"
    have "(i{0..<n}. matr n n (λi. if i = k then 0v n else c i) $$ (i, p i)) = 0" 
      by (rule prod_zero[OF _ bexI[of _ k]], 
      insert k p c[unfolded carrier_vec_def], auto)
  }
  thus ?thesis unfolding det_def by simp
qed

(* The determinant is additive in row k: if row k is a k + b k and all other rows
   are given by c, the determinant is the sum of the determinants with row k
   replaced by a k and by b k, respectively. *)
lemma det_row_add: 
  assumes abc: "a k  carrier_vec n" "b k  carrier_vec n" "c  {0..<n}  carrier_vec n"
    and k: "k < n"
  shows "det(matr n n (λ i. if i = k then a i + b i else c i)) =
    det(matr n n (λ i. if i = k then a i else c i)) +
    det(matr n n (λ i. if i = k then b i else c i))"
  (is "?lhs = ?rhs")
proof -
  let ?n = "{0..<n}"
  let ?m = "λ a b p i. matr n n (λi. if i = k then a i else b i) $$ (i, p i)"
  let ?c = "λ p i. matr n n c $$ (i, p i)"
  let ?ab = "λ i. a i + b i"
  note intros = add_carrier_vec[of _ n]
  have "?rhs = (p{p. p permutes ?n}. 
    signof p * (i?n. ?m a c p i)) + (p{p. p permutes ?n}. signof p * (i?n. ?m b c p i))"
    unfolding det_def by simp
  also have " = (p{p. p permutes ?n}. signof p * (i?n. ?m a c p i) +  signof p * (i?n. ?m b c p i))"
    by (rule sum.distrib[symmetric])
  also have " = (p{p. p permutes ?n}. signof p * (i?n. ?m ?ab c p i))"
  proof (rule sum.cong, force)
    fix p
    assume "p  {p. p permutes ?n}"
    hence p: "p permutes ?n" by simp
    show "signof p * (i?n. ?m a c p i) + signof p * (i?n. ?m b c p i) = 
      signof p * (i?n. ?m ?ab c p i)"
      unfolding distrib_left[symmetric]
    proof (rule arg_cong[of _ _ "λ a. signof p * a"])
      from k have f: "finite ?n" and k': "k  ?n" by auto
      let ?nk = "?n - {k}"
      (* Split each product into the factor for row k and the remaining rows. *)
      note split = prod.remove[OF f k']
      have id1: "(i?n. ?m a c p i) = ?m a c p k * (i?nk. ?m a c p i)"
        by (rule split)
      have id2: "(i?n. ?m b c p i) = ?m b c p k * (i?nk. ?m b c p i)"
        by (rule split)
      have id3: "(i?n. ?m ?ab c p i) = ?m ?ab c p k * (i?nk. ?m ?ab c p i)"
        by (rule split)
      (* Outside row k the three matrices agree with the matrix of rows c. *)
      have id: " a. (i?nk. ?m a c p i) = (i?nk. ?c p i)"
        by (rule prod.cong, insert abc k p, auto intro!: intros)
      have ab: "?ab k  carrier_vec n" using abc by (auto intro: intros)
      {
        fix f
        assume "f k  (carrier_vec n :: 'a vec set)"
        hence "matr n n (λi. if i = k then f i else c i) $$ (k, p k) = f k $ p k"
          by (insert p k abc, auto)
      } note first = this
      note id' = id1 id2 id3
      have dist: "(a k + b k) $ p k = a k $ p k + b k $ p k"  
        by (rule index_add_vec(1), insert p k abc, force)
      show "(i?n. ?m a c p i) + (i?n. ?m b c p i) = (i?n. ?m ?ab c p i)"
        unfolding id' id first[of a, OF abc(1)] first[of b, OF abc(2)] first[of ?ab, OF ab] dist
        by (rule distrib_right[symmetric])
    qed 
  qed 
  also have " = ?lhs" unfolding det_def by simp
  finally show ?thesis by simp
qed


(* Extension of det_row_add to finite sums: if row k is the finsum_vec of the
   vectors a k j over a finite set S, the determinant is the sum over j in S of
   the determinants with row k replaced by a k j.
   Induction over S; the empty case is det_row_0, the step is det_row_add. *)
lemma det_linear_row_finsum:
  assumes fS: "finite S" and c: "c  {0..<n}  carrier_vec n" and k: "k < n"
  and a: "a k  S  carrier_vec n"
  shows "det (matr n n (λ i. if i = k then finsum_vec TYPE('a :: comm_ring_1) n (a i) S else c i)) =
    sum (λj. det (matr n n (λ i. if i = k then a  i j else c i))) S"
proof -
  let ?sum = "finsum_vec TYPE('a) n"
  show ?thesis using a
  proof (induct rule: finite_induct[OF fS])
    case 1
    show ?case
      by (simp, unfold finsum_vec_empty, rule det_row_0[OF k c])
  next
    case (2 x F)
    from 2(4) have ak: "a k  F  carrier_vec n" and akx: "a k x  carrier_vec n" by auto    
    (* Rewrite rule that unfolds the finite sum over insert x F inside the if-expression. *)
    {
      fix i
      note if_cong[OF refl finsum_vec_insert[OF 2(1-2)],
        of _ "a i" n "c i" "c i"]
    } note * = this
    show ?case
    proof (subst *)
      show "det (matr n n (λi. if i = k then a i x + ?sum (a i) F else c i)) =
        (jinsert x F. det (matr n n (λi. if i = k then a i j else c i)))"
      proof (subst det_row_add)
        show "det (matr n n (λi. if i = k then a i x else c i)) +
          det (matr n n (λi. if i = k then ?sum (a i) F else c i)) =
        (jinsert x F. det (matr n n (λi. if i = k then a i j else c i)))"
        unfolding 2(3)[OF ak] sum.insert[OF 2(1-2)] by simp
      qed (insert c k ak akx 2(1), 
        auto intro!: finsum_vec_closed)
    qed (insert akx ak, force+)
  qed
qed


(* Multilinearity for a set T of rows: if every row i in T is a finite sum over S of
   vectors a i j, the determinant expands into a sum over all functions f that map T
   into S and are the identity outside T, where row i in T is replaced by a i (f i).
   Induction over T; each step uses det_linear_row_finsum for the new row z and then
   reindexes pairs (j, f) as functions via sum.reindex_bij_witness. *)
lemma det_linear_rows_finsum_lemma:
  assumes fS: "finite S"
    and fT: "finite T" and c: "c  {0..<n}  carrier_vec n"
    and T: "T  {0 ..< n}"
    and a: "a  T  S  carrier_vec n"
  shows "det (matr n n (λ i. if i  T then finsum_vec TYPE('a :: comm_ring_1) n (a i) S else c i)) =
    sum (λf. det(matr n n (λ i. if i  T then a i (f i) else c i)))
      {f. (i  T. f i  S)  (i. i  T  f i = i)}"
proof -
  let ?sum = "finsum_vec TYPE('a) n"
  show ?thesis using fT c a T
  proof (induct T arbitrary: a c set: finite)
    case empty
    let ?f = "(λ i. i) :: nat  nat"
    have [simp]: "{f. i. f i = i} = {?f}" by auto    
    show ?case by simp
  next
    case (insert z T a c)
    hence z: "z < n" and azS: "a z  S  carrier_vec n" by auto
    let ?F = "λT. {f. (i  T. f i  S)  (i. i  T  f i = i)}"
    (* ?h and ?k convert between pairs (value at z, function on T) and functions on insert z T. *)
    let ?h = "λ(y,g) i. if i = z then y else g i"
    let ?k = "λh. (h(z),(λi. if i = z then i else h i))"
    let ?s = "λ k a c f. det(matr n n (λ i. if i  T then a i (f i) else c i))"
    let ?c = "λj i. if i = z then a i j else c i"
    have thif: "a b c d. (if a  b then c else d) = (if a then c else if b then c else d)"
      by simp
    have thif2: "a b c d e. (if a then b else if c then d else e) =
       (if c then (if a then b else d) else (if a then b else e))"
      by simp
    from z  T have nz: "i. i  T  i = z  False"
      by auto
    from insert have c: " i. i < n  c i  carrier_vec n" by auto
    have fin: "finite {f. (iT. f i  S)  (i. i  T  f i = i)}"
      by (rule finite_bounded_functions[OF fS insert(1)])
    have "det (matr n n (λ i. if i  insert z T then ?sum (a i) S else c i)) =
      det (matr n n (λ i. if i = z then ?sum (a i) S else if i  T then ?sum (a i) S else c i))"
      unfolding insert_iff thif ..
    (* Expand row z with det_linear_row_finsum. *)
    also have " = (jS. det (matr n n (λ i. if i  T then ?sum (a i) S else if i = z then a i j else c i)))"
      apply (subst det_linear_row_finsum[OF fS _ z])
      prefer 3
      apply (subst thif2)
      using nz
      apply (simp cong del: if_weak_cong cong add: if_cong)
      apply (insert azS c fS insert(5), (force intro!: finsum_vec_closed)+)
      done
    (* Expand the remaining rows in T with the induction hypothesis. *)
    also have " = (sum (λ (j, f). det (matr n n (λ i. if i  T then a i (f i)
      else if i = z then a i j
      else c i))) (S × ?F T))"
      unfolding sum.cartesian_product[symmetric]
      by (rule sum.cong[OF refl], subst insert.hyps(3), 
        insert azS c fin z insert(5-6), auto)
    finally have tha:
      "det (matr n n (λ i. if i  insert z T then ?sum (a i) S else c i)) =
       (sum (λ (j, f). det (matr n n (λ i. if i  T then a i (f i)
          else if i = z then a i j
          else c i))) (S × ?F T))" .                
    show ?case unfolding tha
      by (rule sum.reindex_bij_witness[where i="?k" and j="?h"], insert z  T
      azS c fS insert(5-6) z fin, 
      auto intro!: arg_cong[of _ _ det])
  qed
qed

(* Instance of det_linear_rows_finsum_lemma for T = {0..<n}: when every row is a
   finite sum over S, the determinant is a sum over all functions f from {0..<n}
   into S (identity elsewhere) of the determinants with rows a i (f i). *)
lemma det_linear_rows_sum:
  assumes fS: "finite S"
  and a: "a  {0..<n}  S  carrier_vec n"
  shows "det (matr n n (λ i. finsum_vec TYPE('a :: comm_ring_1) n (a i) S)) =
    sum (λf. det (matr n n (λ i. a i (f i)))) 
    {f. (i{0..<n}. f i  S)  (i. i  {0..<n}  f i = i)}"
proof -
  let ?T = "{0..<n}"
  have fT: "finite ?T" by auto
  (* For row indices of an n x n matrix the test against ?T can be dropped. *)
  have th0: "x y. matr n n (λ i. if i  ?T then x i else y i) = matr n n (λ i. x i)"
    by (rule eq_rowI, auto)
  have c: "(λ _. 0v n)  ?T  carrier_vec n" by auto
  show ?thesis
    by (rule det_linear_rows_finsum_lemma[OF fS fT c subset_refl a, unfolded th0])
qed

(* Scaling each row i by a factor c i multiplies the determinant by the product
   of all the factors c i for i in {0..<n}. *)
lemma det_rows_mul:
  assumes a: "a  {0..<n}  carrier_vec n"
  shows "det(matr n n (λ i. c i v a i)) =
    prod c {0..<n} * det(matr n n (λ i. a i))"
proof -
  have A: "matr n n (λ i. c i v a i)  carrier_mat n n" 
  and A': "matr n n (λ i. a i)  carrier_mat n n" using a unfolding carrier_mat_def by auto
  show ?thesis unfolding det_def'[OF A] det_def'[OF A']
  proof (rule trans[OF sum.cong sum_distrib_left[symmetric]])
    fix p
    assume p: "p  {p. p permutes {0..<n}}"
    (* Pull the scalar factors out of the product for a fixed permutation p. *)
    have id: "(ia{0..<n}. matr n n (λi. c i v a i) $$ (ia, p ia))
      = prod c {0..<n} * (ia{0..<n}. matr n n a $$ (ia, p ia))"
      unfolding prod.distrib[symmetric]
      by (rule prod.cong, insert p a, force+)
    show "signof p * (ia{0..<n}. matr n n (λi. c i v a i) $$ (ia, p ia)) =
           prod c {0..<n} * (signof p * (ia{0..<n}. matr n n a $$ (ia, p ia)))"
      unfolding id by auto
  qed simp
qed


(* Row-wise view of the matrix product: row i of A * B is the finite sum over k
   of the rows of B scaled by the entries A $$ (i,k). *)
lemma mat_mul_finsum_alt:
  assumes A: "A  carrier_mat nr n" and B: "B  carrier_mat n nc"
  shows "A * B = matr nr nc (λ i. finsum_vec TYPE('a :: semiring_0) nc (λk. A $$ (i,k) v row B k) {0 ..< n})"
  by (rule eq_matI, insert A B, auto, subst index_finsum_vec, auto simp: scalar_prod_def intro: sum.cong)


(* Multiplicativity of the determinant for square matrices over a commutative ring.
   Outline:
   1. Write A * B row-wise (mat_mul_finsum_alt) and expand det (A * B) with
      det_linear_rows_sum into a sum over all functions f in ?F.
   2. Functions in ?F that are not permutations are not injective, so the
      corresponding matrix has two identical rows and contributes 0 (zth).
   3. For permutations, the summands add up to det A * det B (th2). *)
lemma det_mult:
  assumes A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
  shows "det (A * B) = det A * det (B :: 'a :: comm_ring_1 mat)"
proof -
  let ?U = "{0 ..< n}"
  let ?F = "{f. (i ?U. f i  ?U)  (i. i  ?U  f i = i)}"
  let ?PU = "{p. p permutes ?U}"
  have fU: "finite ?U" 
    by blast
  have fF: "finite ?F"
    by (rule finite_bounded_functions, auto)
  {
    fix p
    assume p: "p permutes ?U"
    have "p  ?F" unfolding mem_Collect_eq permutes_in_image[OF p]
      using p[unfolded permutes_def] by simp
  }
  then have PUF: "?PU  ?F" by blast
  (* Step 2: non-permutations in ?F contribute 0. *)
  {
    fix f
    assume fPU: "f  ?F - ?PU"
    have fUU: "f ` ?U  ?U"
      using fPU by auto
    from fPU have f: "i  ?U. f i  ?U" "i. i  ?U  f i = i" "¬(y. ∃!x. f x = y)"
      unfolding permutes_def by auto
    let ?A = "matr n n (λ i. A $$ (i, f i) v row B (f i))"
    let ?B = "matr n n (λ i. row B (f i))"
    have B': "?B  carrier_mat n n"
      by (intro mat_row_carrierI)
    {
      assume fi: "inj_on f ?U"
      from inj_on_nat_permutes[OF fi] f
      have "f permutes ?U" by auto
      with fPU have False by simp
    } 
    hence fni: "¬ inj_on f ?U" by auto
    then obtain i j where ij: "f i = f j" "i  j" "i < n" "j < n"
      unfolding inj_on_def by auto
    from ij
    have rth: "row ?B i = row ?B j" by auto
    have "det ?A = 0" 
      by (subst det_rows_mul, unfold det_identical_rows[OF B' ij(2-4) rth], insert f A B, auto)
  }
  then have zth: " f. f  ?F - ?PU  det (matr n n (λ i. A $$ (i, f i) v row B (f i))) = 0"
    by simp
  (* Step 3: for a permutation p, factor the inner sum over q into the summand of
     det A at p times the summands of det B, reindexing q by composition with inv p. *)
  {
    fix p
    assume pU: "p  ?PU"
    from pU have p: "p permutes ?U"
      by blast
    let ?s = "λp. (signof p) :: 'a"
    let ?f = "λq. ?s p * ( i ?U. A $$ (i,p i)) * (?s q * (i ?U. B $$ (i, q i)))"
    have "(sum (λq. ?s q *
        (i ?U. matr n n (λ i. A $$ (i, p i) v row B (p i)) $$ (i, q i))) ?PU) =
      (sum (λq. ?s p * ( i ?U. A $$ (i,p i)) * (?s q * ( i ?U. B $$ (i, q i)))) ?PU)"
      unfolding sum_permutations_compose_right[OF permutes_inv[OF p], of ?f]
    proof (rule sum.cong[OF refl])
      fix q
      assume "q  {q. q permutes ?U}"
      hence q: "q permutes ?U" by simp
      from p q have pp: "permutation p" and pq: "permutation q"
        unfolding permutation_permutes by auto
      note sign = signof_compose[OF q permutes_inv[OF p], unfolded signof_inv[OF fU p]]
      let ?inv = "Hilbert_Choice.inv"
      have th001: "prod (λi. B$$ (i, q (?inv p i))) ?U = prod ((λi. B$$ (i, q (?inv p i)))  p) ?U"
        by (rule prod.permute[OF p])
      have thp: "prod (λi. matr n n (λ i. A$$(i,p i) v row B (p i)) $$ (i, q i)) ?U =
        prod (λi. A$$(i,p i)) ?U * prod (λi. B$$ (i, q (?inv p i))) ?U"
        unfolding th001 o_def permutes_inverses[OF p]
        by (subst prod.distrib[symmetric], insert A p q B, auto intro: prod.cong)
      (* Name the two products so that the final step is a rearrangement of factors. *)
      define AA where "AA = (i?U. A $$ (i, p i))"
      define BB where "BB = (ia{0..<n}. B $$ (ia, q (?inv p ia)))"
      have "?s q * (ia{0..<n}. matr n n (λi. A $$ (i, p i) v row B (p i)) $$ (ia, q ia)) =
         ?s p * (i{0..<n}. A $$ (i, p i)) * (?s (q  ?inv p) * (ia{0..<n}. B $$ (ia, q (?inv p ia))))"
        unfolding sign thp
        unfolding AA_def[symmetric] BB_def[symmetric]
        by (simp add: ac_simps signof_def)
      thus "?s q * (i = 0..<n. matr n n (λi. A $$ (i, p i) v row B (p i)) $$ (i, q i)) =
         ?s p * (i = 0..<n. A $$ (i, p i)) *
         (?s (q  ?inv p) * (i = 0..<n. B $$ (i, (q  ?inv p) i)))" by simp
    qed 
  } note * = this
  have th2: "sum (λf. det (matr n n (λ i. A$$(i,f i) v row B (f i)))) ?PU = det A * det B"
    unfolding det_def'[OF A] det_def'[OF B] det_def'[OF mat_row_carrierI]
    unfolding sum_product dim_row_mat
    by (rule sum.cong, insert A, force, subst *, insert A B, auto)
  let ?f = "λ f. det (matr n n (λ i. A $$ (i, f i) v row B (f i)))"
  (* Step 1 and final assembly: expand, split ?F into non-permutations and permutations. *)
  have "det (A * B) = sum ?f ?F"
    unfolding mat_mul_finsum_alt[OF A B]
    by (rule det_linear_rows_sum[OF fU], insert A B, auto)
  also have " = sum ?f ((?F - ?PU)  (?F  ?PU))"
    by (rule arg_cong[where f = "sum ?f"], auto)
  also have " = sum ?f (?F - ?PU) + sum ?f (?F  ?PU)"
    by (rule sum.union_disjoint, insert A B finite_bounded_functions[OF fU fU], auto)
  also have "sum ?f (?F - ?PU) = 0"
    by (rule sum.neutral, insert zth, auto)
  also have "?F  ?PU = ?PU" unfolding permutes_def by fastforce
  also have "sum ?f ?PU = det A * det B"
    unfolding th2 ..
  finally show ?thesis by simp
qed

(* A unit of the matrix ring has non-zero determinant: from B * A = 1 and
   det_mult we get det B * det A = 1. *)
lemma unit_imp_det_non_zero: assumes "A  Units (ring_mat TYPE('a :: comm_ring_1) n b)"
   shows "det A  0"
proof -
  from assms[unfolded Units_def ring_mat_def]
  obtain B where A: "A  carrier_mat n n" and B: "B  carrier_mat n n" and BA: "B * A = 1m n" by auto
  from arg_cong[OF BA, of det, unfolded det_mult[OF B A] det_one]
  show ?thesis by auto
qed

text ‹The following proof is based on the Gauss-Jordan algorithm.›

(* Over a field, a square matrix with non-zero determinant is a unit of the matrix ring.
   Proof by contradiction: if A is not a unit, the Gauss-Jordan result B is not the
   identity and therefore has a zero last row, so det B = 0; but B = P * A for a
   unit P, so det B is non-zero. *)
lemma det_non_zero_imp_unit: assumes A: "A  carrier_mat n n"
  and dA: "det A  (0 :: 'a :: field)"
  shows "A  Units (ring_mat TYPE('a) n b)"
proof (rule ccontr)
  let ?g = "gauss_jordan A (0m n 0)"
  let ?B = "fst ?g"
  obtain B C where B: "?g = (B,C)" by (cases ?g)
  assume "¬ ?thesis"
  from this[unfolded gauss_jordan_check_invertable[OF A zero_carrier_mat[of n 0]] B]
  have "B  1m n" by auto
  with row_echelon_form_imp_1_or_0_row[OF gauss_jordan_carrier(1)[OF A _ B] gauss_jordan_row_echelon[OF A B], of 0]
  have n: "0 < n" and row: "row B (n - 1) = 0v n" by auto
  let ?n = "n - 1"
  from n have n1: "?n < n" by auto
  from gauss_jordan_transform[OF A _ B, of 0 b] obtain P
    where P: "PUnits (ring_mat TYPE('a) n b)" and PA: "B = P * A" by auto
  from unit_imp_det_non_zero[OF P] have dP: "det P  0" by auto
  from P have P: "P  carrier_mat n n" unfolding Units_def ring_mat_def by auto
  from det_mult[OF P A] dP dA have "det B  0" unfolding PA by simp
  (* The zero last row forces det B = 0 via det_row_0, contradicting the line above. *)
  also have "det B = 0" 
  proof -
    from gauss_jordan_carrier[OF A _ B, of 0] have B: "B  carrier_mat n n" by auto
    {
      fix j
      assume j: "j < n"
      from index_row(1)[symmetric, of ?n B j, unfolded row] B
      have "B $$ (?n, j) = 0" using B n j by auto
    }
    hence "B = matr n n (λi. if i = ?n then 0v n else row B i)"
      by (intro eq_matI, insert B, auto)
    also have "det  = 0"
      by (rule det_row_0[OF n1], insert B, auto)
    finally show "det B = 0" .
  qed 
  finally show False by simp
qed

(* Over a field, a right inverse of a square matrix is also a left inverse:
   A * B = 1 implies B * A = 1. Both matrices have non-zero determinant and hence
   are units, so Units_inv_comm applies. *)
lemma mat_mult_left_right_inverse: assumes A: "(A :: 'a :: field mat)  carrier_mat n n" 
  and B: "B  carrier_mat n n" and AB: "A * B = 1m n"
  shows "B * A = 1m n"
proof -
  let ?R = "ring_mat TYPE('a) n undefined"
  from det_mult[OF A B, unfolded AB] have "det A  0" "det B  0" by auto
  from det_non_zero_imp_unit[OF A this(1)] det_non_zero_imp_unit[OF B this(2)]  
  have U: "A  Units ?R" "B  Units ?R" .
  interpret ring ?R by (rule ring_mat)
  from Units_inv_comm[unfolded ring_mat_simps, OF AB U] show ?thesis .
qed

(* Over a field, if det A = 0 then there is a unit P such that P * A is in row
   echelon form with a zero last row; in particular n > 0.
   P and P * A are obtained from the Gauss-Jordan algorithm applied to A. *)
lemma det_zero_imp_zero_row: assumes A: "(A :: 'a :: field mat)  carrier_mat n n"
  and det: "det A = 0"
  shows " P. P  Units (ring_mat TYPE('a) n b)  row (P * A) (n - 1) = 0v n  0 < n
     row_echelon_form (P * A)"
proof -
  let ?R = "ring_mat TYPE('a) n b"
  let ?U = "Units ?R"
  interpret m: ring ?R by (rule ring_mat)
  let ?g = "gauss_jordan A A"
  obtain A' B' where g: "?g = (A', B')" by (cases ?g)
  (* det A = 0 rules out that A is a unit, so the Gauss-Jordan result is not the identity. *)
  from det unit_imp_det_non_zero[of A n b] have AU: "A  ?U" by auto
  with gauss_jordan_inverse_one_direction(1)[OF A A, of _ b]
  have A'1: "A'  1m n" using g by auto
  from gauss_jordan_carrier(1)[OF A A g] have A': "A'  carrier_mat n n" by auto
  from gauss_jordan_row_echelon[OF A g] have re: "row_echelon_form A'" .
  from row_echelon_form_imp_1_or_0_row[OF A' this] A'1
  have n: "0 < n" and row: "row A' (n - 1) = 0v n" by auto
  from gauss_jordan_transform[OF A A g, of b] obtain P
    where P: "P  ?U" and A': "A' = P * A" by auto
  thus ?thesis using n row re by auto
qed

(* Over a field, det A = 0 holds iff there is a non-zero vector v with A *v v = 0.
   - If det A is non-zero, A has a left inverse B, so A *v v = 0 forces v = 0.
   - If det A = 0, take the row echelon form PA = P * A with zero last row
     (det_zero_imp_zero_row). It has fewer than n pivot positions, so
     find_base_vector yields a non-zero v with PA *v v = 0; multiplying by the
     inverse Q of P transfers this to A. *)
lemma det_0_iff_vec_prod_zero_field: assumes A: "(A :: 'a :: field mat)  carrier_mat n n"
  shows "det A = 0  ( v. v  carrier_vec n  v  0v n  A *v v = 0v n)" (is "?l = ( v. ?P v)")
proof -
  let ?R = "ring_mat TYPE('a) n ()"
  let ?U = "Units ?R"
  interpret m: ring ?R by (rule ring_mat)
  show ?thesis
  proof (cases "det A = 0")
    case False
    from det_non_zero_imp_unit[OF A this, of "()"]
    have "A  ?U" .
    then obtain B where unit: "B * A = 1m n" and B: "B  carrier_mat n n"
      unfolding Units_def ring_mat_def by auto
    {
      fix v
      assume "?P v"
      hence v: "v  carrier_vec n" "v  0v n" "A *v v = 0v n" by auto
      have "v = (B * A) *v v" using v B unfolding unit by auto
      also have " = B *v (A *v v)" using B A v by simp
      also have " = B *v 0v n" unfolding v ..
      also have " = 0v n" using B by auto
      finally have False using v by simp
    }
    with False show ?thesis by blast
  next
    case True
    let ?n = "n - 1"
    from det_zero_imp_zero_row[OF A True, of "()"]
    obtain P where PU: "P  ?U" and row: "row (P * A) ?n = 0v n" and n: "0 < n" "?n < n"
      and re: "row_echelon_form (P * A)" by auto
    define PA where "PA = P * A"
    note row = row[folded PA_def]
    note re = re[folded PA_def]
    from PU obtain Q where P: "P  carrier_mat n n" and Q: "Q  carrier_mat n n"
      and unit: "Q * P = 1m n" "P * Q =  1m n" unfolding Units_def ring_mat_def by auto    
    from P A have PA: "PA  carrier_mat n n" and dimPA: "dim_row PA = n" unfolding PA_def by auto
    from re[unfolded row_echelon_form_def] obtain p where p: "pivot_fun PA p n" using PA by auto 
    note piv = pivot_positions[OF PA p]
    note pivot = pivot_funD[OF dimPA p n(2)]
    (* The last row is zero, so it has no pivot: p ?n = n. *)
    {
      assume "p ?n < n"
      with pivot(4)[OF this] n arg_cong[OF row, of "λ v. v $ p ?n"] have False using PA by auto
    }
    with pivot(1) have pn: "p ?n = n" by fastforce
    (* Hence all pivot positions lie in the first n - 1 rows, so there are fewer than n of them. *)
    with piv(1) have "set (pivot_positions PA)   {(i, p i) |i. i < n  p i  n} - {(?n,p ?n)}" by auto
    also have "  {(i, p i) | i. i < ?n}" using n by force
    finally have "card (set (pivot_positions PA))  card {(i, p i) | i. i < ?n}"
      by (intro card_mono, auto)
    also have "{(i, p i) | i. i < ?n} = (λ i. (i, p i)) ` {0 ..< ?n}" by auto
    also have "card  = card {0 ..< ?n}" by (rule card_image, auto simp: inj_on_def)
    also have " < n" using n by simp
    finally have "card (set (pivot_positions PA)) < n" .
    hence "card (snd ` (set (pivot_positions PA))) < n" 
      using card_image_le[OF finite_set, of snd "pivot_positions PA"] by auto
    hence neq: "snd ` (set (pivot_positions PA))  {0 ..< n}" by auto
    from find_base_vector[OF re PA neq] obtain v where v: "v  carrier_vec n"
      and v0: "v  0v n" and pav: "PA *v v = 0v n" by auto
    (* Transfer the kernel vector from PA back to A using Q * P = 1. *)
    have "A *v v = Q * P *v (A *v v)" unfolding unit using A v by auto
    also have " = Q *v (PA *v v)" unfolding PA_def using Q P A v by auto
    also have "PA *v v = 0v n" unfolding pav ..
    also have "Q *v 0v n = 0v n" using Q by auto
    finally have Av: "A *v v = 0v n" by auto
    show ?thesis unfolding True using Av v0 v by auto
  qed
qed

text ‹In order to get the result for integral domains, we embed the domain in its
  fraction field, and then apply the result for fields.›
(* Version of det_0_iff_vec_prod_zero_field for integral domains.  A is mapped
   entrywise into the fraction field via to_fract; the field result is applied
   to the mapped matrix, and witnesses are transferred in both directions. *)
lemma det_0_iff_vec_prod_zero: assumes A: "(A :: 'a :: idom mat)  carrier_mat n n"
  shows "det A = 0  ( v. v  carrier_vec n  v  0v n  A *v v = 0v n)"
proof -
  let ?h = "to_fract :: 'a  'a fract"
  let ?A = "map_mat ?h A"
  have A': "?A  carrier_mat n n" using A by auto
  interpret inj_comm_ring_hom ?h by (unfold_locales, auto)
  have "(det A = 0) = (?h (det A) = ?h 0)" by auto
  also have " = (det ?A = 0)" unfolding hom_zero hom_det ..
  also have " = (( v. v  carrier_vec n  v  0v n  ?A *v v = 0v n))"
    unfolding det_0_iff_vec_prod_zero_field[OF A'] ..
  also have " = (( v. v  carrier_vec n  v  0v n  A *v v = 0v n))" (is "?l = ?r")
  proof
    (* A witness over 'a is mapped entrywise into the fraction field. *)
    assume ?r
    then obtain v where v: "v  carrier_vec n" "v  0v n" "A *v v = 0v n" by auto
    show ?l
      by (rule exI[of _ "map_vec ?h v"], insert v, auto simp: mult_mat_vec_hom[symmetric, OF A v(1)])
  next
    (* A witness v over the fraction field: write each entry as
       Fract (a i) (b i), then scale v by the product m of the denominators
       b i for i in [0..<n] to obtain a vector w with entries in 'a. *)
    assume ?l
    then obtain v where v: "v  carrier_vec n" and v0: "v  0v n" and Av: "?A *v v = 0v n" by auto
    have " i.  a b. v $ i = Fraction_Field.Fract a b  b  0" using Fract_cases[of "v $ i" for i] by metis
    from choice[OF this] obtain a where " i.  b. v $ i = Fraction_Field.Fract (a i) b  b  0" by metis
    from choice[OF this] obtain b where vi: " i. v $ i = Fraction_Field.Fract (a i) (b i)" and bi: " i. b i  0" by auto
    define m where "m = prod_list (map b [0..<n])"
    let ?m = "?h m"
    have m0: "m  0" unfolding m_def hom_0_iff prod_list_zero_iff using bi by auto
    from v0[unfolded vec_eq_iff] v obtain i where i: "i < n" "v $ i  0" by auto
    (* Each b i divides m, so every scaled entry is the image of an element
       of 'a under to_fract. *)
    {
      fix i
      assume "i < n"
      hence "b i  set (map b [0 ..< n])" by auto
      from split_list[OF this]
        obtain ys zs where "map b [0..<n] = ys @ b i # zs" by auto
      hence "b i dvd m" unfolding m_def by auto
      then obtain c where "m = b i * c" ..
      hence "?m * v $ i = ?h (a i * c)" unfolding vi using bi[of i]
        by (simp add: eq_fract to_fract_def)
      hence " c. ?m * v $ i = ?h c" ..
    }
    hence " i.  c. i < n  ?m * v $ i = ?h c" by auto
    from choice[OF this] obtain c where c: " i. i < n  ?m * v $ i = ?h (c i)" by auto
    define w where "w = vec n c"
    have w: "w  carrier_vec n" unfolding w_def by simp
    have mvw: "?m v v = map_vec ?h w" unfolding w_def using c v
      by (intro eq_vecI, auto)
    with m0 i c[OF i(1)] have "w $ i  0" unfolding w_def by auto
    with i w have w0: "w  0v n" by auto
    (* Scaling commutes with the matrix-vector product, and the embedding is
       injective (vec_hom_inj), so w is annihilated by A. *)
    from arg_cong[OF Av, of "λ v. ?m v v"]
    have "?m v (?A *v v) = map_vec ?h (0v n)" by auto
    also have "?m v (?A *v v) = ?A *v (?m v v)" using A v by auto
    also have " = ?A *v (map_vec ?h w)" unfolding mvw ..
    also have " = map_vec ?h (A *v w)" unfolding mult_mat_vec_hom[OF A w] ..
    finally have "A *v w = 0v n" by (rule vec_hom_inj)
    with w w0 show ?r by blast
  qed
  finally show ?thesis .
qed

(* Negating a square matrix does not change whether its determinant is 0.
   Both sides are rewritten with det_0_iff_vec_prod_zero, after showing that
   - A and A annihilate the same vectors. *)
lemma det_0_negate: assumes  A: "(A :: 'a :: field mat)  carrier_mat n n"
  shows "(det (- A) = 0) = (det A = 0)"
proof -
  from A have mA: "- A  carrier_mat n n" by auto
  {
    fix v :: "'a vec"
    assume v: "v  carrier_vec n"
    hence Av: "A *v v  carrier_vec n" using A by auto
    have id: "- A *v v = - (A *v v)" using v A by simp
    have "(- A *v v = 0v n) = (A *v v = 0v n)" unfolding id 
      unfolding uminus_zero_vec_eq[OF Av] ..
  }
  thus ?thesis unfolding det_0_iff_vec_prod_zero[OF A] det_0_iff_vec_prod_zero[OF mA] by auto
qed
  
(* Multiplying row k of a square matrix by a multiplies the determinant by a.
   The row operation is expressed as left multiplication by multrow_mat and
   det_mult is applied. *)
lemma det_multrow: 
  assumes k: "k < n" and A: "A  carrier_mat n n"
  shows "det (multrow k a A) = a * det A"
proof -
  have "multrow k a A = multrow_mat n k a * A"
    by (rule multrow_mat[OF A])
  also have "det (multrow_mat n k a * A) = det (multrow_mat n k a) * det A"
    by (rule det_mult[OF _ A], auto)
  also have "det (multrow_mat n k a) = a"
    by (rule det_multrow_mat[OF k])
  finally show ?thesis .
qed

(* Dividing the determinant of a row-scaled matrix by the scaling factor a
   (assumed via a0) gives back the determinant of the original matrix. *)
lemma det_multrow_div:
  assumes k: "k < n" and A: "A  carrier_mat n n" and a0: "a  0"
  shows "det (multrow k a A :: 'a :: idom_divide mat) div a = det A"
proof -
  (* det_multrow rewrites the scaled determinant to a * det A *)
  note scaled = det_multrow[OF k A]
  from a0 show ?thesis unfolding scaled by simp
qed

(* The addrow operation leaves the determinant unchanged.  It is expressed as
   left multiplication by addrow_mat, whose determinant is 1. *)
lemma det_addrow: 
  assumes l: "l < n" and k: "k  l" and A: "A  carrier_mat n n"
  shows "det (addrow a k l A) = det A"
proof -
  have "addrow a k l A = addrow_mat n a k l * A"
    by (rule addrow_mat[OF A l])
  also have "det (addrow_mat n a k l * A) = det (addrow_mat n a k l) * det A"
    by (rule det_mult[OF _ A], auto)
  also have "det (addrow_mat n a k l) = 1"
    by (rule det_addrow_mat[OF k])
  finally show ?thesis using A by simp
qed

(* Swapping the two rows k and l (both below n, see assumption k) negates the
   determinant.  The swap is expressed as left multiplication by swaprows_mat,
   whose determinant is - 1. *)
lemma det_swaprows: 
  assumes *: "k < n" "l < n" and k: "k  l" and A: "A  carrier_mat n n"
  shows "det (swaprows k l A) = - det A"
proof -
  have "swaprows k l A = swaprows_mat n k l * A"
    by (rule swaprows_mat[OF A *])
  also have "det (swaprows_mat n k l * A) = det (swaprows_mat n k l) * det A"
    by (rule det_mult[OF _ A], insert A, auto)
  also have "det (swaprows_mat n k l) = - 1"
    by (rule det_swaprows_mat[OF * k])
  finally show ?thesis using A by simp
qed

(* Similar matrices have the same determinant: from A = P * B * Q with
   P * Q = 1m n, multiplicativity of det gives det A = det P * det B * det Q
   and det P * det Q = 1. *)
lemma det_similar: assumes "similar_mat A B" 
  shows "det A = det B"
proof -
  from similar_matD[OF assms] obtain n P Q where
  carr: "{A, B, P, Q}  carrier_mat n n" (is "_  ?C")
  and PQ: "P * Q = 1m n" 
  and AB: "A = P * B * Q" by blast
  hence A: "A  ?C" and B: "B  ?C" and P: "P  ?C" and Q: "Q  ?C" by auto  
  (* PQ is re-bound here to the corresponding fact about determinants *)
  from det_mult[OF P Q, unfolded PQ] have PQ: "det P * det Q = 1" by auto
  from det_mult[OF _ Q, of "P * B", unfolded det_mult[OF P B] AB[symmetric]] P B
  have "det A = det P * det B * det Q" by auto
  also have " = (det P * det Q) * det B" by (simp add: ac_simps)
  also have " = det B" unfolding PQ by simp
  finally show ?thesis .
qed

(* Determinant of a block matrix whose upper-right block is a single zero
   column (0m n 1) and whose lower-right block A4 is 1 x 1:
   the determinant is det A1 * det A4.
   The proof works directly on the permutation-sum form of the determinant
   (mat_det_left_def) for the (Suc n) x (Suc n) matrix. *)
lemma det_four_block_mat_upper_right_zero_col: assumes A1: "A1  carrier_mat n n"
  and A20: "A2 = (0m n 1)" and A3: "A3  carrier_mat 1 n"
  and A4: "A4  carrier_mat 1 1"
  shows "det (four_block_mat A1 A2 A3 A4) = det A1 * det A4" (is "det ?A = _")
proof -
  let ?A = "four_block_mat A1 A2 A3 A4"
  from A20 have A2: "A2  carrier_mat n 1" by auto  
  define A where "A = ?A"
  from four_block_carrier_mat[OF A1 A4] A1
  have A: "A  carrier_mat (Suc n) (Suc n)" and dim: "dim_row A1 = n" unfolding A_def by auto
  let ?Pn = "λ p. p permutes {0 ..< n}"
  let ?Psn = "λ p. p permutes {0 ..< Suc n}"
  let ?perm = "{p. ?Psn p}"
  let ?permn = "{p. ?Pn p}"
  let ?prod = "λ p. signof p * (i = 0..<Suc n. A $$ (p i, i))"
  let ?prod' = "λ p. A $$ (p n, n) * signof p * (i = 0..<n. A $$ (p i, i))"
  let ?prod'' = "λ p. signof p * (i = 0..< n. A $$ (p i, i))"
  let ?prod''' = "λ p. signof p * (i = 0..< n. A1 $$ (p i, i))"
  let ?p0 = "{p. p 0 = 0}"
  have [simp]: "{0..<Suc n} - {n} = {0..< n}" by auto
  (* Split the factor for column n off the product of each summand. *)
  {
    fix p
    assume "?Psn p"
    have "?prod p = signof p * (A $$ (p n, n) * ( i  {0..< n}. A $$ (p i, i)))"
      by (subst prod.remove[of _ n], auto)
    also have " = A $$ (p n, n) * signof p * ( i  {0..< n}. A $$ (p i, i))" by simp
    finally have "?prod p = ?prod' p" .
  } note prod_id = this
  define prod' where "prod' = ?prod'"
  (* For q permuting {0 ..< n} and i < n, the permutation obtained by
     composing with the swap of n and i sends n to a row index below n;
     that entry of column n is 0, so the whole summand vanishes. *)
  {
    fix i q
    assume i: "i  {0..< n}" "q permutes {0 ..< n}"
    hence "Fun.swap n i id (q n) < n" 
      unfolding permutes_def by auto
    hence "A $$ (Fun.swap n i id (q n), n) = 0"
      unfolding A_def using A1 A20 A3 A4 by auto
    hence "prod' (Fun.swap n i id  q) = 0"
      unfolding prod'_def by simp 
  } note zero = this
  have cong: " a b c. b = c  a * b = a * c" by auto
  have "det ?A = sum ?prod ?perm"
    unfolding A_def[symmetric] using mat_det_left_def[OF A] A by simp
  also have " = sum prod' ?perm" unfolding prod'_def
    by (rule sum.cong[OF refl], insert prod_id, auto)
  also have "{0 ..< Suc n} = insert n {0 ..< n}" by auto
  (* Decompose the sum over permutations of insert n {0 ..< n} using
     sum_over_permutations_insert; only the part for i = n survives. *)
  also have "sum prod' {p. p permutes } = 
    (iinsert n {0..<n}. q?permn. prod' (Fun.swap n i id  q))"
    by (subst sum_over_permutations_insert, auto)
  also have " = (q?permn. prod' q) +
    (i{0..<n}. q?permn. prod' (Fun.swap n i id  q))"
    by (subst sum.insert, auto)
  also have "(i{0..<n}. q?permn. prod' (Fun.swap n i id  q)) = 0"
    by (rule sum.neutral, intro ballI, rule sum.neutral, intro ballI, rule zero, auto)
  also have "(q ?permn. prod' q) = A $$ (n,n) * (q ?permn. ?prod'' q)"
    unfolding prod'_def
    by (subst sum_distrib_left, rule sum.cong[OF refl], auto simp: permutes_def ac_simps)
  also have "A $$ (n,n) = A4 $$ (0,0)" unfolding A_def using A1 A2 A3 A4 by auto
  (* On indices below n the block matrix agrees with A1, so the remaining sum
     is the permutation-sum form of det A1. *)
  also have "(q ?permn. ?prod'' q) = (q ?permn. ?prod''' q)" 
    by (rule sum.cong[OF refl], rule cong, rule prod.cong,
    insert A1 A2 A3 A4, auto simp: permutes_def A_def)
  also have " = det A1"
    unfolding mat_det_left_def[OF A1] dim by auto
  also have "A4 $$ (0,0) = det A4"
    using A4 unfolding det_def[of A4] by (auto simp: signof_def sign_def)
  finally show ?thesis by simp
qed

(* Exchanging the initial block of k rows with the following block of n rows
   (new row i is old row i + k for i < n, old row i - n for n <= i < k + n,
   and unchanged otherwise) multiplies the determinant by (- 1) ^ (k * n).
   The permutation is built from adjacent row swaps via the local helpers
   sw, swb and swbl. *)
lemma det_swap_initial_rows: assumes A: "A  carrier_mat m m" 
  and lt: "k + n  m" 
  shows "det A = (- 1) ^ (k * n) *
    det (mat m m (λ(i, j). A $$ (if i < n then i + k else if i < k + n then i - n else i, j)))" 
proof -
  (* sw A xs: apply the row swaps listed in xs to A, one after another. *)
  define sw where "sw = (λ (A :: 'a mat) xs. fold (λ (i,j). swaprows i j) xs A)"
  have dim_sw[simp]: "dim_row (sw A xs) = dim_row A" "dim_col (sw A xs) = dim_col A" for xs A
    unfolding sw_def by (induct xs arbitrary: A, auto)
  (* Each swap in the list contributes a factor - 1 (det_swaprows). *)
  {
    fix xs and A :: "'a mat"
    assume "dim_row A = dim_col A" " i j. (i,j)  set xs  i < dim_col A  j < dim_col A  i  j"
    hence "det (sw A xs) = (-1)^(length xs) * det A"
      unfolding sw_def
    proof (induct xs arbitrary: A)
      case (Cons xy xs A)
      obtain x y where xy: "xy = (x,y)" by force
      from Cons(3)[unfolded xy, of x y] Cons(2)
      have [simp]: "det (swaprows x y A) = - det A"
        by (intro det_swaprows, auto)
      show ?case unfolding xy by (simp, insert Cons(2-), (subst Cons(1), auto)+)
    qed simp
  } note sw = this
  (* swb A i n: the n adjacent swaps (j, Suc j) for j in [i ..< i + n].
     As the characterisation below states, the old row k ends up at position
     k + n and the rows in between move up by one. *)
  define swb where "swb = (λ A i n. sw A (map (λ j. (j,Suc j)) [i ..< i + n]))"
  {
    fix k n and A :: "'a mat"
    assume k_n: "k + n < dim_row A"
    hence "swb A k n = mat (dim_row A) (dim_col A) (λ (i,j). let r = 
      (if i < k  i > k + n then i else if i = k + n then k else Suc i)
      in A $$ (r,j))"
    proof (induct n)
      case 0
      show ?case unfolding swb_def sw_def by (rule eq_matI, auto)
    next
      case (Suc n)
      hence dim: "k + n < dim_row A" by auto
      have id: "swb A k (Suc n) = swaprows (k + n) (Suc k + n) (swb A k n)" unfolding swb_def sw_def by simp
      show ?case unfolding id Suc(1)[OF dim]
        by (rule eq_matI, insert Suc(2), auto)
    qed
  } note swb = this
  (* swbl A k n: apply swb for i = k - 1 down to 0 (fold over rev [0 ..< k]);
     its effect is the row permutation of the lemma statement. *)
  define swbl where "swbl = (λ A k n. fold (λ i A. swb A i n) (rev [0 ..< k]) A)"
  {
    fix k n and A :: "'a mat"
    assume k_n: "k + n  dim_row A"
    hence "swbl A k n = mat (dim_row A) (dim_col A) (λ (i,j). let r = 
      (if i < n then i + k else if i < k + n then i - n else i)
      in A $$ (r,j))"
    proof (induct k arbitrary: A)
      case 0
      thus ?case unfolding swbl_def by (intro eq_matI, auto simp: swb)
    next
      case (Suc k)
      hence dim: "k + n < dim_row A" by auto
      have id: "swbl A (Suc k) n = swbl (swb A k n) k n" unfolding swbl_def by simp
      show ?case unfolding id swb[OF dim]
        by (subst Suc(1), insert dim, force, intro eq_matI, auto simp: less_Suc_eq_le) 
    qed
  } note swbl = this
  (* Each of the k applications of swb consists of n swaps, giving the sign
     (-1)^(k*n) in total. *)
  {
    fix k n and A :: "'a mat"
    assume k_n: "k + n  dim_col A" "dim_row A = dim_col A" 
    hence "det (swbl A k n) = (-1)^(k*n) * det A" 
    proof (induct k arbitrary: A)
      case 0
      thus ?case unfolding swbl_def by auto
    next
      case (Suc k)
      hence dim: "k + n < dim_row A" by auto
      have id: "swbl A (Suc k) n = swbl (swb A k n) k n" unfolding swbl_def by simp
      have det: "det (swb A k n) = (-1)^n * det A" unfolding swb_def
        by (subst sw, insert Suc(2-), auto)
      show ?case unfolding id 
        by (subst Suc(1), insert Suc(2-), auto simp: det, auto simp: swb power_add)
    qed
  } note det_swbl = this
  from assms have dim: "dim_row A = dim_col A" "k + n  dim_col A" "k + n  dim_row A" "dim_col A = m" by auto
  (* Combine the determinant fact with the explicit form of swbl and multiply
     both sides by the sign once more. *)
  from arg_cong[OF det_swbl[OF dim(2,1), unfolded swbl[OF dim(3)], unfolded Let_def dim], 
      of  "λ x. (-1)^(k*n) * x"]
  show ?thesis by simp
qed

(* Special case of det_swap_initial_rows for a (k + n) x (k + n) matrix:
   new row i is old row i + n for i < k and old row i - k otherwise.
   det_swap_initial_rows is instantiated with the roles of k and n exchanged
   (see fact le). *)
lemma det_swap_rows: assumes A: "A  carrier_mat (k + n) (k + n)" 
  shows "det A = (-1)^(k * n) * det (mat (k + n) (k + n) (λ (i,j). 
    A $$ ((if i < k then i + n else i - k),j)))" 
proof -
  have le: "n + k  k + n" by simp
  show ?thesis unfolding det_swap_initial_rows[OF A le]
    by (intro arg_cong2[of _ _ _ _ "λ x y. ((-1)^x * det y)"], force, intro eq_matI, auto)
qed

(* For m = l + k + n: keep the first l rows and reorder the remaining k + n
   rows (new row i is old row i + k for l <= i < l + n, old row i - n
   afterwards).  The determinant changes by the factor (- 1) ^ (k * n).
   The proof chains det_swap_rows, det_swap_initial_rows and det_swap_rows;
   the two extra signs cancel because their exponents add up to an even
   number. *)
lemma det_swap_final_rows: assumes A: "A  carrier_mat m m"
  and m: "m = l + k + n" 
  shows "det A = (- 1) ^ (k * n) *
    det (mat m m (λ(i, j). A $$ (if i < l then i else if i < l + n then i + k else i - n, j)))" 
    (is "_ = _ * det ?M")
proof -
  (* l k n -swap-rows→ k n l -swap-initial→ n k l -swap-rows→ l n k *)
  have m1: "m = (k + n) + l" using m by simp
  have m2: "k + n  m" using m by simp
  have m3: "m = l + (n + k)" using m by simp
  define M where "M = ?M" 
  let ?M1 = "mat m m (λ(i, j). A $$ (if i < k + n then i + l else i - (k + n), j))" 
  let ?M2 = "mat m m
          (λ(i, j). A $$ (if i < n then i + k + l else if i < k + n then i - n + l else i - (k + n), j))" 
  have M2: "?M2  carrier_mat m m" by auto
  have "det A = (- 1) ^ ((k + n) * l) * det ?M1" 
    unfolding det_swap_rows[OF A[unfolded m1]] m1[symmetric] by simp
  also have "det ?M1 = (- 1) ^ (k * n) * det ?M2"
    by (subst det_swap_initial_rows[OF _ m2], force, rule arg_cong[of _ _ "λ x. _ * det x"],
    rule eq_matI, auto simp: m)
  also have "det ?M2 = (- 1) ^ (l * (n + k)) * det M" 
    unfolding M_def det_swap_rows[OF M2[unfolded m3], folded m3]
    by (rule arg_cong[of _ _ "λ x. _ * det x"], rule eq_matI, auto simp: m)
  finally have "det A = (-1) ^ ((k + n) * l + (k * n) + l * (n + k)) * det M" (is "_ = ?b ^ _ * _")
    by (simp add: power_add)
  (* The exponent equals 2 * (l * (n + k)) + k * n, and the even part drops out. *)
  also have "(k + n) * l + (k * n) + l * (n + k) = 2 * (l * (n + k)) + k * n" by simp
  also have "?b ^  = ?b ^ (k * n)" by (simp add: power_add)
  finally show ?thesis unfolding M_def .
qed

(* Column version of det_swap_final_rows, obtained by transposing: apply
   det_transpose, use the row lemma on the transposed matrix, and transpose
   the result back. *)
lemma det_swap_final_cols: assumes A: "A  carrier_mat m m"
  and m: "m = l + k + n" 
  shows "det A = (- 1) ^ (k * n) *
    det (mat m m (λ(i, j). A $$ (i, if j < l then j else if j < l + n then j + k else j - n)))" 
proof -
  have "det A = det (AT)" unfolding det_transpose[OF A] ..
  also have " = (- 1) ^ (k * n) *
    det (mat m m (λ(i, j). AT $$ (if i < l then i else if i < l + n then i + k else i - n, j)))" 
    (is "_ = _ * det ?M")
    by (rule det_swap_final_rows[OF _ m], insert A, auto)
  also have "det ?M = det (?MT)" by (subst det_transpose, auto)
  also have "?MT = mat m m (λ(i, j). A $$ (i, if j < l then j else if j < l + n then j + k else j - n))" 
    unfolding transpose_mat_def using A m
    by (intro eq_matI, auto)
  finally show ?thesis .
qed

(* Column version of det_swap_initial_rows, obtained by transposing: apply
   det_transpose, use the row lemma on the transposed matrix, and transpose
   the result back. *)
lemma det_swap_initial_cols: assumes A: "A  carrier_mat m m" 
  and lt: "k + n  m" 
  shows "det A = (- 1) ^ (k * n) *
    det (mat m m (λ(i, j). A $$ (i, if j < n then j + k else if j < k + n then j - n else j)))" 
proof -
  have "det A = det (AT)" unfolding det_transpose[OF A] ..
  also have " = (- 1) ^ (k * n) *
    det (mat m m (λ(j, i). AT $$ (if j < n then j + k else if j < k + n then j - n else j,i)))" 
    (is "_ = _ * det ?M")
    by (rule det_swap_initial_rows[OF _ lt], insert A, auto)
  also have "det ?M = det (?MT)" by (subst det_transpose, auto)
  also have "?MT = mat m m (λ(i, j). A $$ (i, if j < n then j + k else if j < k + n then j - n else j))" 
    unfolding transpose_mat_def using A lt
    by (intro eq_matI, auto)
  finally show ?thesis .
qed
  
(* Column counterpart of det_swap_rows for a (k + n) x (k + n) matrix:
   new column j is old column j + n for j < k and old column j - k otherwise.
   det_swap_initial_cols is instantiated with the roles of k and n exchanged
   (see fact le). *)
lemma det_swap_cols: assumes A: "A  carrier_mat (k + n) (k + n)" 
  shows "det A = (-1)^(k * n) * det (mat (k + n) (k + n) (λ (i,j). 
   A $$ (i,(if j < k then j + n else j - k))))" (is "_ = _ * det ?B")
proof -
  have le: "n + k  k + n" by simp
  show ?thesis unfolding det_swap_initial_cols[OF A le]
    by (intro arg_cong2[of _ _ _ _ "λ x y. ((-1)^x * det y)"], force, intro eq_matI, auto)
qed  
  
(* Determinant of a block matrix over an integral domain whose upper-right
   block is the n x m zero matrix: det A1 * det A4.
   Proof by induction on m.  In the step case, row operations on the lower
   blocks (tracked together with a nonzero scaling factor v) clear the last
   column of the lower-right block above its bottom entry; then the
   single-column lemma det_four_block_mat_upper_right_zero_col and the
   induction hypothesis apply, and v is cancelled at the end. *)
lemma det_four_block_mat_upper_right_zero: fixes A1 :: "'a :: idom mat" 
  assumes A1: "A1  carrier_mat n n"
  and A20: "A2 = (0m n m)" and A3: "A3  carrier_mat m n"
  and A4: "A4  carrier_mat m m"  
shows "det (four_block_mat A1 A2 A3 A4) = det A1 * det A4" 
  using assms(2-)
proof (induct m arbitrary: A2 A3 A4)  
  case (0 A2 A3 A4)
  (* m = 0: the block matrix is A1 itself and A4 is the 0 x 0 identity. *)
  hence *: "four_block_mat A1 A2 A3 A4 = A1" using A1
    by (intro eq_matI, auto)
  from 0 have 4: "A4 = 1m 0" by auto
  show ?case unfolding * unfolding 4 by simp
next
  case (Suc m A2 A3 A4)    
  let ?m = "Suc m" 
  from Suc have A2: "A2  carrier_mat n ?m" by auto
  note A20 = Suc(2)
  note A34 = Suc(3-4)
  let ?A = "four_block_mat A1 A2 A3 A4" 
  (* Invariant ?P B3 B4 v k: the lower blocks B3, B4 arise from A3, A4 such
     that both determinants are scaled by the same nonzero factor v, and the
     first k entries of column m of B4 are 0. *)
  let ?P = "λ B3 B4 v k. v  0  v * det ?A = det (four_block_mat A1 A2 B3 B4)
     v * det A4 = det B4  B3  carrier_mat ?m n  B4  carrier_mat ?m ?m  ( i < k. B4 $$ (i,m) = 0)" 
  have "k  m   B3 B4 v. ?P B3 B4 v k" for k
  proof (induct k)
    case 0
    have "?P A3 A4 1 0" using A34 by auto
    thus ?case by blast
  next
    case (Suc k)
    then obtain B3 B4 v where v: "v  0" and det: "v * det ?A = 
      det (four_block_mat A1 A2 B3 B4)" "v * det A4 = det B4" 
     and B3: "B3  carrier_mat ?m n" and B4: "B4  carrier_mat ?m ?m"  and 0: " i < k. B4 $$ (i,m) = 0" by auto
    show ?case
    proof (cases "B4 $$ (k,m) = 0")
      case True
      (* Entry (k, m) is already 0: nothing to do. *)
      with 0 have 0: " i < Suc k. B4 $$ (i,m) = 0" using less_Suc_eq by auto
      with v det B3 B4 have "?P B3 B4 v (Suc k)" by auto
      thus ?thesis by blast
    next
      case Bk: False
      let ?k = "Suc k" 
      from Suc(2) have k: "k < ?m" "Suc k < ?m" "k  Suc k" by auto      
      show ?thesis
      proof (cases "B4 $$ (?k,m) = 0")
        case True
        (* Entry (Suc k, m) is 0: swap rows k and Suc k of the lower blocks;
           this negates both determinants, so v is replaced by - v. *)
        let ?B4 = "swaprows k (Suc k) B4" 
        let ?B3 = "swaprows k (Suc k) B3" 
        let ?B = "four_block_mat A1 A2 ?B3 ?B4" 
        let ?v = "-v" 
        from det_swaprows[OF k B4] det have det1: "?v * det A4 = det ?B4" by simp
        from v have v: "?v  0" by auto
        from B3 have B3': "?B3  carrier_mat ?m n" by auto
        from B4 have B4': "?B4  carrier_mat ?m ?m" by auto
        have "?v * det ?A = - det (four_block_mat A1 A2 B3 B4)" using det by simp            
        also have " = det (swaprows (n + k) (n + ?k) (four_block_mat A1 A2 B3 B4))" 
          by (rule sym, rule det_swaprows[of _ "n + ?m"], insert A1 A2 B3 B4 k, auto)
        (* Swapping rows n + k and n + Suc k of the whole matrix is the same
           as swapping rows k and Suc k inside the lower blocks. *)
        also have "swaprows (n + k) (n + ?k) (four_block_mat A1 A2 B3 B4) = ?B" 
        proof (rule eq_matI, unfold index_mat_four_block index_mat_swaprows, goal_cases)
          case (1 i j)
          show ?case
          proof (cases "i < n")
            case True
            thus ?thesis using 1(2) A1 A2 B3 B4 by simp
          next
            case False
            hence "i = n + (i - n)" by simp
            then obtain d where "i = n + d" by blast 
            thus ?thesis using 1 A1 A2 B3 B4 k(2) by simp
          qed
        qed auto
        finally have det2: "?v * det ?A = det ?B" .
        from True 0 B4 k(2) have " i < Suc k. ?B4 $$ (i,m) = 0" unfolding less_Suc_eq by auto
        with det1 det2 B3' B4' v have "?P ?B3 ?B4 ?v (Suc k)" by auto
        thus ?thesis by blast
      next
        case False
        (* Both entries are nonzero: scale row k by the entry (Suc k, m) and
           combine it with row Suc k via addrow so that entry (k, m) becomes
           0.  Both determinants are scaled by that entry, so v is
           multiplied by it. *)
        let ?bk = "B4 $$ (?k,m)" 
        let ?b = "B4 $$ (k,m)" 
        let ?v = "v * ?bk" 
        let ?B3 = "addrow (- ?b) k ?k (multrow k ?bk B3)" 
        let ?B4 = "addrow (- ?b) k ?k (multrow k ?bk B4)" 
        have *: "det ?B4 = ?bk * det B4" 
          by (subst det_addrow[OF k(2-3)], force simp: B4, rule det_multrow[OF k(1) B4])
        with det(2)[symmetric] have det2: "?v * det A4 = det ?B4" by (auto simp: ac_simps)
        from 0 k(2) B4 have 0: " i < Suc k. ?B4 $$ (i,m) = 0" unfolding less_Suc_eq by auto
        from False v have v: "?v  0" by auto
        from B3 have B3': "?B3  carrier_mat ?m n" by auto
        from B4 have B4': "?B4  carrier_mat ?m ?m" by auto
        let ?B' = "multrow (n + k) ?bk (four_block_mat A1 A2 B3 B4)" 
        have B': "?B'  carrier_mat (n + ?m) (n + ?m)" using A1 A2 B3 B4 k by auto          
        let ?B = "four_block_mat A1 A2 ?B3 ?B4" 
        have "?v * det ?A = ?bk * det (four_block_mat A1 A2 B3 B4)" using det by simp            
        also have " = det (addrow (- ?b) (n + k) (n + ?k) ?B')" 
          by (subst det_addrow[OF _ _ B'], insert k(2), force, force, rule sym, rule det_multrow[of _ "n + ?m"],
          insert A1 A2 B3 B4 k, auto)
        (* The same row operations on the whole matrix (row indices shifted
           by n) coincide with performing them inside the lower blocks. *)
        also have "addrow (- ?b) (n + k) (n + ?k) ?B' = ?B" 
        proof (rule eq_matI, unfold index_mat_four_block index_mat_multrow index_mat_addrow, goal_cases)
          case (1 i j)
          show ?case
          proof (cases "i < n")
            case True
            thus ?thesis using 1(2) A1 A2 B3 B4 by simp
          next
            case False
            hence "i = n + (i - n)" by simp
            then obtain d where "i = n + d" by blast 
            thus ?thesis using 1 A1 A2 B3 B4 k(2) by simp
          qed
        qed auto
        finally have det1: "?v * det ?A = det ?B" .
        from det1 det2 B3' B4' v 0 have "?P ?B3 ?B4 ?v (Suc k)" by auto
        thus ?thesis by blast
      qed
    qed
  qed
  (* Instantiate the invariant at k = m: column m of B4 is 0 in all rows
     below index m. *)
  from this[OF le_refl] obtain B3 B4 v where P: "?P B3 B4 v m" by blast
  let ?B = "four_block_mat A1 A2 B3 B4" 
  from P have v: "v  0" and det: "v * det ?A = det ?B" "v * det A4 = det B4" 
    and B3: "B3  carrier_mat ?m n" and B4: "B4  carrier_mat ?m ?m" and 0: " i. i < m  B4 $$ (i, m) = 0" 
    by auto
  (* Re-partition: split off the last row and last column, both of B4 and of
     the whole matrix, so that the upper-right block is a single zero
     column. *)
  let ?A2 = "0m n m"  
  let ?A3 = "mat m n (λ ij. B3 $$ ij)" 
  let ?A4 = "mat m m (λ ij. B4 $$ ij)" 
  let ?B1 = "four_block_mat A1 ?A2 ?A3 ?A4" 
  let ?B2 = "0m (n + m) 1" 
  let ?B3 = "mat 1 (n + m) (λ (i,j). if j < n then B3 $$ (m,j) else B4 $$ (m,j - n))" 
  let ?B4 = "mat 1 1 (λ _. B4 $$ (m,m))" 
  have B44: "B4 = four_block_mat ?A4 (0m m 1) (mat 1 m (λ (i,j). B4 $$ (m,j))) ?B4" 
  proof (rule eq_matI, unfold index_mat_four_block dim_col_mat dim_row_mat, goal_cases)
    case (1 i j)
    hence [simp]: "¬ i < m  i = m" "¬ j < m  j = m" by auto
    from 1 show ?case using B4 0 by auto
  qed (insert B4, auto)
  have "?B = four_block_mat ?B1 ?B2 ?B3 ?B4"
  proof (rule eq_matI, unfold index_mat_four_block dim_col_mat dim_row_mat, goal_cases)
    case (1 i j)
    then consider (UL) "i < n + m" "j < n + m" | (UR) "i < n + m" "j = n + m" 
        | (LL) "i = n + m" "j < n + m" | (LR) "i = n + m" "j = n + m" using A1 by auto linarith
    thus ?case
    proof cases
      case UL
      hence [simp]: "¬ i < n  i - n < m" 
         "¬ j < n  j - n < m" "¬ j < n  j - n < Suc m" by auto
      from UL show ?thesis using A1 A20 B3 B4 by simp
    next
      case LL
      hence [simp]: "¬ j < n  j - n < m" "¬ j < n  j - n < Suc m" by auto
      from LL show ?thesis using A1 A2 B3 B4 by simp
    next
      case LR
      thus ?thesis using A1 A2 B3 B4 by simp
    next
      case UR
      hence [simp]: "¬ i < n  i - n < m" by auto
      from UR show ?thesis using A1 A20 0 B3 B4 by simp
    qed
  qed (insert B4, auto)
  hence "det ?B = det (four_block_mat ?B1 ?B2 ?B3 ?B4)" by simp
  (* Single-column lemma for the outer partition, induction hypothesis
     (Suc(1)) for ?B1, and the single-column lemma again to reassemble
     det B4. *)
  also have " = det ?B1 * det ?B4" 
    by (rule det_four_block_mat_upper_right_zero_col[of _ "n + m"], insert A1 A2 B3 B4, auto)
  also have "det ?B1 = det A1 * det (mat m m (($$) B4))"  
    by (rule Suc(1), insert B3 B4, auto)
  also have " * det ?B4 = det A1 * (det (mat m m (($$) B4)) * det ?B4)" by simp
  also have "det (mat m m (($$) B4)) * det ?B4 = det B4"
    unfolding arg_cong[OF B44, of det] 
    by (subst det_four_block_mat_upper_right_zero_col[OF _ refl], auto)
  finally have id: "det ?B = det A1 * det B4" .
  (* Both sides carry the factor v, which is nonzero and can be cancelled. *)
  from this[folded det] have "v * det ?A = v * (det A1 * det A4)" by simp
  with v show "det ?A = det A1 * det A4" by simp
qed
  
(* Swapping two columns negates the determinant.  The column swap is
   rewritten as transpose / swaprows / transpose
   (swapcols_is_transp_swap_rows), and det_transpose and det_swaprows are
   applied. *)
lemma det_swapcols: 
  assumes *: "k < n" "l < n" "k  l" and A: "A  carrier_mat n n"
  shows "det (swapcols k l A) = - det A"
proof -
  let ?B = "transpose_mat A"
  let ?C = "swaprows k l ?B"
  let ?D = "transpose_mat ?C"
  have C: "?C  carrier_mat n n" and B: "?B  carrier_mat n n"
    unfolding transpose_carrier_mat swaprows_carrier using A by auto
  show ?thesis 
    unfolding 
      swapcols_is_transp_swap_rows[OF A *(1-2)]
      det_transpose[OF C] det_swaprows[OF * B] det_transpose[OF A] ..
qed


(* swap_row_to_front A I changes the determinant by the sign (-1)^I.
   Induction on I: the step case peels off the adjacent swap of rows I and
   Suc I, which contributes one factor -1 by det_swaprows. *)
lemma swap_row_to_front_det: "A  carrier_mat n n  I < n  det (swap_row_to_front A I)
  = (-1)^I * det A"
proof (induct I arbitrary: A)
  case (Suc I A)
  from Suc(3) have I: "I < n" by auto
  let ?I = "Suc I"
  let ?A = "swaprows I ?I A"
  have AA: "?A  carrier_mat n n" using Suc(2) by simp
  have "det (swap_row_to_front A (Suc I)) = det (swap_row_to_front ?A I)" by simp
  also have " = (-1)^I * det ?A" by (rule Suc(1)[OF AA I])
  also have "det ?A = -1 * det A" using det_swaprows[OF I Suc(3) _ Suc(2)] by simp
  finally show ?case by simp
qed simp

(* swap_col_to_front A I changes the determinant by the sign (-1)^I.
   Induction on I: the step case peels off the adjacent swap of columns I and
   Suc I, which contributes one factor -1 by det_swapcols. *)
lemma swap_col_to_front_det: "A  carrier_mat n n  I < n  det (swap_col_to_front A I)
  = (-1)^I * det A"
proof (induct I arbitrary: A)
  case (Suc I A)
  from Suc(3) have I: "I < n" by auto
  let ?I = "Suc I"
  let ?A = "swapcols I ?I A"
  have AA: "?A  carrier_mat n n" using Suc(2) by simp
  have "det (swap_col_to_front A (Suc I)) = det (swap_col_to_front ?A I)" by simp
  also have " = (-1)^I * det ?A" by (rule Suc(1)[OF AA I])
  also have "det ?A = -1 * det A" using det_swapcols[OF I Suc(3) _ Suc(2)] by simp
  finally show ?case by simp
qed simp


(* When the lower blocks A3, A4 form a single row, moving row n of the block
   matrix to the front exchanges the upper and lower block rows.  Proved by
   rewriting with swap_row_to_front_result and comparing entries. *)
lemma swap_row_to_front_four_block: assumes A1: "A1  carrier_mat n m1"
  and A2: "A2  carrier_mat n m2" 
  and A3: "A3  carrier_mat 1 m1" 
  and A4: "A4  carrier_mat 1 m2"
  shows "swap_row_to_front (four_block_mat A1 A2 A3 A4) n = four_block_mat A3 A4 A1 A2"
  by (subst swap_row_to_front_result[OF four_block_carrier_mat[OF A1 A4]], force,
  rule eq_matI, insert A1 A2 A3 A4, auto)

(* When the right blocks A2, A4 form a single column, moving column m of the
   block matrix to the front exchanges the left and right block columns.
   Proved by rewriting with swap_col_to_front_result and comparing entries. *)
lemma swap_col_to_front_four_block: assumes A1: "A1  carrier_mat n1 m"
  and A2: "A2  carrier_mat n1 1" 
  and A3: "A3  carrier_mat n2 m" 
  and A4: "A4  carrier_mat n2 1"
  shows "swap_col_to_front (four_block_mat A1 A2 A3 A4) m = four_block_mat A2 A1 A4 A3"
  by (subst swap_col_to_front_result[OF four_block_carrier_mat[OF A1 A4]], force,
  rule eq_matI, insert A1 A2 A3 A4, auto)

(* Block matrix whose lower-right block is a single zero column (0m n 1):
   the determinant is (-1)^n * det A2 * det A3.
   The block rows are exchanged via swap_row_to_front (sign (-1)^n), which
   reduces the claim to det_four_block_mat_upper_right_zero_col. *)
lemma det_four_block_mat_lower_right_zero_col: assumes A1: "A1  carrier_mat 1 n"
  and A2: "A2  carrier_mat 1 1"
  and A3: "A3  carrier_mat n n"
  and A40: "A4 = (0m n 1)" 
  shows "det (four_block_mat A1 A2 A3 A4) = (-1)^n * det A2 * det A3" (is "det ?A = _")
proof -
  let ?B = "four_block_mat A3 A4 A1 A2"
  from four_block_carrier_mat[OF A3 A2]
  have B: "?B  carrier_mat (Suc n) (Suc n)" by simp
  from A40 have A4: "A4  carrier_mat n 1" by auto
  from arg_cong[OF swap_row_to_front_four_block[OF A3 A4 A1 A2], of det]
    swap_row_to_front_det[OF B, of n]
  have "det ?A = (-1)^n * det ?B" by auto
  also have "det ?B = det A3 * det A2"
    by (rule det_four_block_mat_upper_right_zero_col[OF A3 A40 A1 A2])
  finally show ?thesis by simp
qed
  
(* Block matrix whose lower-left block is a single zero column (0m n 1) and
   whose upper-left block A1 is 1 x 1: the determinant is det A1 * det A4.
   The block columns are exchanged via swap_col_to_front (sign (-1)^n) and
   det_four_block_mat_lower_right_zero_col contributes a second factor
   (-1)^n; the two signs cancel. *)
lemma det_four_block_mat_lower_left_zero_col: assumes A1: "A1  carrier_mat 1 1"
  and A2: "A2  carrier_mat 1 n"
  and A30: "A3 = (0m n 1)" 
  and A4: "A4  carrier_mat n n"
  shows "det (four_block_mat A1 A2 A3 A4) = det A1 * det A4" (is "det ?A = _")
proof -
  from A30 have A3: "A3  carrier_mat n 1" by auto
  let ?B = "four_block_mat A2 A1 A4 A3"
  from four_block_carrier_mat[OF A2 A3]
  have B: "?B  carrier_mat (Suc n) (Suc n)" by simp
  from arg_cong[OF swap_col_to_front_four_block[OF A2 A1 A4 A3], of det]
    swap_col_to_front_det[OF B, of n]
  have "det ?A = (-1)^n * det ?B" by auto
  also have "det ?B = (- 1) ^ n * det A1 * det A4"
    by (rule det_four_block_mat_lower_right_zero_col[OF A2 A1 A4 A30])
  also have "(-1)^n *  = (-1 * -1)^n * det A1 * det A4"
    unfolding power_mult_distrib by (simp add: ac_simps)
  finally show ?thesis by simp
qed

(* The addcol operation leaves the determinant unchanged.  It is expressed as
   right multiplication by addrow_mat (with the two indices exchanged), whose
   determinant is 1.  Declared as a simp rule. *)
lemma det_addcol[simp]: 
  assumes l: "l < n" and k: "k  l" and A: "A  carrier_mat n n"
  shows "det (addcol a k l A) = det A"
proof -
  have "addcol a k l A = A * addrow_mat n a l k"
    using addcol_mat[OF A l].
  also have "det (A * addrow_mat n a l k) = det A * det (addrow_mat n a l k)"
    by(rule det_mult[OF A], auto)
  also have "det (addrow_mat n a l k) = 1"
    using det_addrow_mat[OF k[symmetric]].
  finally show ?thesis using A by simp
qed

(* insert_index i: index shift that skips position i; indices below i are
   kept, all others are incremented by one. *)
definition "insert_index i  λi'. if i' < i then i' else Suc i'"

(* delete_index i: indices below i are kept, all others are decremented by
   one (natural-number subtraction of Suc 0). *)
definition "delete_index i  λi'. if i' < i then i' else i' - Suc 0"

(* Simp rules evaluating insert_index in its two cases: below the insertion
   point the index is unchanged, otherwise it is the successor. *)
lemma insert_index[simp]:
  "i' < i  insert_index i i' = i'"
  "i'  i  insert_index i i' = Suc i'"
  by (auto simp: insert_index_def)


(* delete_index i undoes insert_index i. *)
lemma delete_insert_index[simp]:
  "delete_index i (insert_index i i') = i'"
  by (auto simp: insert_index_def delete_index_def)

(* insert_index i undoes delete_index i on every index covered by
   assumption i'i. *)
lemma insert_delete_index:
  assumes i'i: "i'  i"
  shows "insert_index i (delete_index i i') = i'"
  using i'i by (auto simp: insert_index_def delete_index_def)

(* delete_dom p i: p precomposed with insert_index i, i.e. the argument i is
   skipped in the domain. *)
definition "delete_dom p i  λi'. p (insert_index i i')"

(* delete_ran p j: p postcomposed with delete_index j. *)
definition "delete_ran p j  λi. delete_index j (p i)"

(* permutation_delete p i: remove the argument i from the domain of p and
   the value p i from its range. *)
definition "permutation_delete p i = delete_ran (delete_dom p i) (p i)"

(* insert_ran p j: p postcomposed with insert_index j, so the value j is
   skipped in the range. *)
definition "insert_ran p j  λi. insert_index j (p i)"

(* insert_dom p i j: below i behave like p, at i return j, and above i
   behave like p on the predecessor argument. *)
definition "insert_dom p i j 
  λi'. if i' < i then p i' else if i' = i then j else p (i'-1)"

(* permutation_insert i j p: make room for the value j in the range of p
   (insert_ran) and then map the new argument i to j (insert_dom). *)
definition "permutation_insert i j p  insert_dom (insert_ran p j) i j"

(* permutation_delete_def with the helper definitions delete_dom, delete_ran,
   insert_index and delete_index unfolded. *)
lemmas permutation_delete_expand =
  permutation_delete_def[unfolded delete_dom_def delete_ran_def insert_index_def delete_index_def]

(* permutation_insert_def with the helper definitions insert_dom, insert_ran,
   insert_index and delete_index unfolded. *)
lemmas permutation_insert_expand =
  permutation_insert_def[unfolded insert_dom_def insert_ran_def insert_index_def delete_index_def]

(* At the inserted argument i, permutation_insert i j p returns the inserted
   value j. *)
lemma permutation_insert_inserted[simp]:
  "permutation_insert (i::nat) j p i = j"
  by (simp add: permutation_insert_expand)

(* For p permuting {0..<n}, inserting the argument n with value n does not
   change p.  Proved pointwise by comparing x with n, using permutes_others
   for arguments outside {0..<n}. *)
lemma permutation_insert_base:
  assumes p: "p permutes {0..<n}"
  shows "permutation_insert n n p = p"
proof (rule ext)
  fix x show "permutation_insert n n p x = p x"
    apply (cases rule: linorder_cases[of "x" "n"])
    unfolding permutation_insert_expand
    using permutes_others[OF p] p by auto
qed

(* Moving the insertion argument from Suc i to i: permutation_insert (Suc i) j p
   combined on the argument side with the swap of i and Suc i equals
   permutation_insert i j p.  Proved pointwise by comparing x with i. *)
lemma permutation_insert_row_step:
  shows "permutation_insert (Suc i) j p  Fun.swap i (Suc i) id = permutation_insert i j p"
    (is "?l = ?r")
proof (rule ext)
  fix x show "?l x = ?r x"
    by (cases rule: linorder_cases[of "x" "i"])
      (auto simp add: swap_id_eq permutation_insert_expand)
qed

(* Moving the inserted value from Suc j to j corresponds to combining
   permutation_insert i (Suc j) p with the transposition of j and Suc j
   on the value side. Proved pointwise: compare x with i, then the
   relevant value of p with j. *)
lemma permutation_insert_column_step:
  assumes p: "p permutes {0..<n}" and "j < n"
  shows "(Fun.swap j (Suc j) id)  (permutation_insert i (Suc j) p) = permutation_insert i j p"
    (is "?l = ?r")
proof (rule ext)
  fix x show "?l x = ?r x"
  proof (cases rule: linorder_cases[of "x" "i"])
    case less note x = this
      (* below the insertion point the value comes from p x *)
      show ?thesis
        apply (cases rule: linorder_cases[of "p x" "j"])
        unfolding permutation_insert_expand using x by simp+
    next case equal thus ?thesis by simp
    next case greater note x = this
      (* above the insertion point the value comes from p (x-1) *)
      show ?thesis
        apply (cases rule: linorder_cases[of "p (x-1)" "j"])
        unfolding permutation_insert_expand using x by simp+
  qed
qed

(* Image of delete_dom f i on {0..<n}: it is the image of f on {0..<Suc n}
   with the value f i removed, provided i is in {0..<Suc n} and (assumption iff)
   i is the only argument in that set at which f takes the value f i.
   Both inclusions are shown elementwise. *)
lemma delete_dom_image:
  assumes i: "i  {0..<Suc n}" (is "_  ?N")
  assumes iff: "i'  ?N. f i' = f i  i' = i"
  shows "delete_dom f i ` {0..<n} = f ` ?N - {f i}" (is "?L = ?R")
proof(unfold set_eq_iff, intro allI iffI)
  fix j'
  (* First inclusion: a value of delete_dom f i is a value of f other than f i *)
  { assume L: "j'  ?L"
    then obtain i' where i': "i'  {0..<n}" and dj': "delete_dom f i i' = j'" by auto
    show "j'  ?R"
    proof(cases "i' < i")
      case True
        (* i' lies below the gap, so delete_dom f i i' is f i' *)
        show ?thesis
          unfolding image_def
          unfolding Diff_iff
          unfolding mem_Collect_eq singleton_iff
        proof(intro conjI bexI)
          show "j'  f i"
          proof
            assume j': "j' = f i"
            hence "f i' = f i"
              using dj'[unfolded delete_dom_def insert_index_def] using True by simp
            thus "False" using iff i True by auto
          qed
          show "j' = f i'"
            using dj' True unfolding delete_dom_def insert_index_def by simp
        qed (insert i',simp)
      next case False
        (* i' is not below the gap, so delete_dom f i i' is f (Suc i') *)
        show ?thesis
          unfolding image_def
          unfolding Diff_iff
          unfolding mem_Collect_eq singleton_iff
        proof(intro conjI bexI)
          show Si': "Suc i'  ?N" using i' by auto
          show "j'  f i"
          proof
            assume j': "j' = f i"
            hence "f (Suc i') = f i"
              using dj'[unfolded delete_dom_def insert_index_def] j' False by simp
            thus "False" using iff Si' False by auto
          qed
          show "j' = f (Suc i')"
            using dj' False unfolding delete_dom_def insert_index_def by simp
        qed
    qed
  }
  (* Second inclusion: a value f i' other than f i is reached via i' or i' - 1 *)
  { assume R: "j'  ?R"
    then obtain i'
      where i': "i'  ?N" and j'fi: "j'  f i" and j'fi': "j' = f i'" by auto
    hence i'i: "i'  i" using iff by auto
    hence n: "n > 0" using i i' by auto
    show "j'  ?L"
    proof (cases "i' < i")
      case True show ?thesis
        proof
          show "j' = delete_dom f i i'"
            unfolding delete_dom_def insert_index_def using True j'fi' by simp
        qed (insert True i, simp)
      next case False show ?thesis
        proof
          show "i'-1  {0..<n}" using i' n by auto
          show "j' = delete_dom f i (i'-1)"
            unfolding delete_dom_def insert_index_def using False j'fi' i'i by auto
        qed
    qed
  }
qed

(* If f maps {0..<n} onto {0..<Suc n} without j, then delete_ran f j maps
   {0..<n} onto {0..<n}. Both inclusions are shown elementwise; for the
   second one the preimage is found for j' itself (when j' < j) or for
   Suc j' (otherwise). *)
lemma delete_ran_image:
  assumes j: "j  {0..<Suc n}" (is "_  ?N")
  assumes fimg: "f ` {0..<n} =  ?N - {j}"
  shows "delete_ran f j ` {0..<n} = {0..<n}" (is "?L = ?R")
proof(unfold set_eq_iff, intro allI iffI)
  fix j'
  { assume L: "j'  ?L"
    then obtain i where i: "i  {0..<n}" and ij': "delete_ran f j i = j'" by auto
    have "f i  ?N - {j}" using fimg i by blast
    thus "j'  ?R" using ij' j unfolding delete_ran_def delete_index_def by auto
  }
  { assume R: "j'  ?R" show "j'  ?L"
    proof (cases "j' < j")
      case True
        (* below j: delete_index j leaves the value unchanged *)
        hence "j'  ?N - {j}" using R by auto
        then obtain i where fij': "f i = j'" and i: "i  {0..<n}"
          unfolding fimg[symmetric] by auto
        have "delete_ran f j i = j'"
          unfolding delete_ran_def delete_index_def unfolding fij' using True by simp
        thus ?thesis using i by auto
      next case False
        (* not below j: j' is obtained from the value Suc j' *)
        hence "Suc j'  ?N - {j}" using R by auto
        then obtain i where fij': "f i = Suc j'" and i: "i  {0..<n}"
          unfolding fimg[symmetric] by auto
        have "delete_ran f j i = j'"
          unfolding delete_ran_def delete_index_def unfolding fij' using False by simp
        thus ?thesis using i by auto
    qed
  }
qed

(* delete_index i is injective on any set S that satisfies assumption iS,
   from which the proof derives that elements of S differ from i.
   The proof compares both arguments with i. *)
lemma delete_index_inj_on:
  assumes iS: "i  S"
  shows "inj_on (delete_index i) S"
proof(intro inj_onI)
  fix x y
  assume eq: "delete_index i x = delete_index i y" and x: "x  S" and y: "y  S"
  have "x  i" "y  i" using x y iS by auto
  thus "x = y"
    using eq unfolding delete_index_def
    by(cases "x < i"; cases "y < i";simp)
qed

(* insert_index i is injective on every set S (no side condition).
   The proof compares both arguments with i. *)
lemma insert_index_inj_on:
  shows "inj_on (insert_index i) S"
proof(intro inj_onI)
  fix x y
  assume eq: "insert_index i x = insert_index i y" and x: "x  S" and y: "y  S"
  show "x = y"
    using eq unfolding insert_index_def
    by(cases "x < i"; cases "y < i";simp)
qed


(* If f is injective on {0..<Suc n}, then delete_dom f i is injective on {0..<n}.
   Shown by a cardinality argument (eq_card_imp_inj_on): the image of
   delete_dom f i has as many elements as {0..<n}, using delete_dom_image. *)
lemma delete_dom_inj_on:
  assumes i: "i  {0..<Suc n}" (is "_  ?N")
  assumes inj: "inj_on f ?N"
  shows "inj_on (delete_dom f i) {0..<n}"
proof (rule eq_card_imp_inj_on)
  have "card ?N = card (f ` ?N)" using card_image[OF inj]..
  hence "card {0..<n} = card (f ` ?N - {f i})" using i by auto
  also have "... = card (delete_dom f i ` {0..<n})"
    apply(subst delete_dom_image[symmetric])
    using i inj unfolding inj_on_def by auto
  finally show "card (delete_dom f i ` {0..<n}) = card {0..<n}"..
qed simp

(* Under the hypotheses of delete_ran_image, delete_ran f j is injective on
   {0..<n}: its image is {0..<n} itself, so eq_card_imp_inj_on applies. *)
lemma delete_ran_inj_on:
  assumes j: "j  {0..<Suc n}" (is "_  ?N")
  assumes img: "f ` {0..<n} =  ?N - {j}"
  shows "inj_on (delete_ran f j) {0..<n}"
  apply (rule eq_card_imp_inj_on)
  unfolding delete_ran_image[OF j img] by simp+

(* If p is a bijection on {0..<Suc n} and i is in that set, then
   permutation_delete p i is a bijection on {0..<n}. The proof combines the
   image lemma for delete_dom with the image and injectivity lemmas for
   delete_ran. *)
lemma permutation_delete_bij_betw:
  assumes i: "i  {0 ..< Suc n}" (is "_  ?N")
  assumes bij: "bij_betw p ?N ?N"
  shows "bij_betw (permutation_delete p i) {0..<n} {0..<n}" (is "bij_betw ?p _ _")
proof-
  have inj: "inj_on p ?N" using bij_betw_imp_inj_on[OF bij].
  have ran: "p ` ?N = ?N" using bij_betw_imp_surj_on[OF bij].
  hence j: "p i  ?N" using i by auto
  (* injectivity of p gives the uniqueness hypothesis of delete_dom_image *)
  have "i'?N. p i' = p i  i' = i" using inj i unfolding inj_on_def by auto
  from delete_dom_image[OF i this]
  have "delete_dom p i ` {0..<n} = ?N - {p i}" unfolding ran.
  from delete_ran_inj_on[OF j this] delete_ran_image[OF j this]
  show ?thesis unfolding permutation_delete_def
    using bij_betw_imageI by blast
qed

(* If p permutes {0..<Suc n} and i < Suc n, then permutation_delete p i
   permutes {0..<n}. Bijectivity comes from permutation_delete_bij_betw; the
   remaining obligation of bij_imp_permutes is that indices outside {0..<n}
   are fixed, which is derived from permutes_others for p. *)
lemma permutation_delete_permutes:
  assumes p: "p permutes {0 ..< Suc n}" (is "_ permutes ?N")
      and i: "i < Suc n"
  shows "permutation_delete p i permutes {0..<n}" (is "?p permutes ?N'")
proof (rule bij_imp_permutes, rule permutation_delete_bij_betw)
  have pi: "p i < Suc n" using p i by auto
  show "bij_betw p ?N ?N" using permutes_imp_bij[OF p].
  fix x assume "x  {0..<n}" hence x: "x  n" by simp
    show "?p x = x"
    proof(cases "x < i")
      case True thus ?thesis
        unfolding permutation_delete_def using x i by simp
      next case False
        hence "p (Suc x) = Suc x" using x permutes_others[OF p] by auto
        thus ?thesis
        unfolding permutation_delete_expand using False pi x by simp
    qed
qed (insert i,auto)

(* Deleting argument i from a permutation p of {0..<Suc n} and re-inserting
   the pair (i, p i) gives back p. Proved pointwise: compare i' with i, and
   inside each case compare p i' with p i; the cases where p i' = p i for
   i' different from i are ruled out by injectivity of p. *)
lemma permutation_insert_delete:
  assumes p: "p permutes {0..<Suc n}"
      and i: "i < Suc n"
  shows "permutation_insert i (p i) (permutation_delete p i) = p"
    (is "?l = _")
proof
  fix i'
  show "?l i' = p i'"
  proof (cases rule: linorder_cases[of "i'" "i"])
    case less note i'i = this
      show ?thesis
      proof (cases "p i = p i'")
        case True
          (* impossible: p is injective and i' < i *)
          hence "i = i'" using permutes_inj[OF p] injD by metis
          hence False using i'i by auto
          thus ?thesis by auto
        next case False thus ?thesis
          unfolding permutation_insert_expand permutation_delete_expand
          using i'i by auto
      qed
    next case equal thus ?thesis unfolding permutation_insert_expand by simp
    next case greater hence i'i: "i' > i" by auto
      hence cond: "¬ i' - 1 < i" using i'i by simp
      show ?thesis
      proof (cases rule: linorder_cases[of "p i'" "p i"])
        case less
          (* value below p i: deletion leaves it unchanged *)
          hence pd: "permutation_delete p i (i'-1) = p i'"
            unfolding permutation_delete_expand
            using i'i cond by auto
          show ?thesis
            unfolding permutation_insert_expand pd
            using i'i less by simp
        next case equal
          (* impossible: p is injective and i' > i *)
          hence "i = i'" using permutes_inj[OF p] injD by metis
          hence False using i'i by auto
          thus ?thesis by auto
        next case greater
          (* value above p i: deletion decrements it, insertion restores it *)
          hence pd: "permutation_delete p i (i'-1) = p i' - 1"
            unfolding permutation_delete_expand
            using i'i cond by simp
          show ?thesis
            unfolding permutation_insert_expand pd
            using i'i greater by auto
      qed
  qed
qed

(* insert_index i never returns the gap position i itself. *)
lemma insert_index_exclude[simp]:
  "insert_index i i'  i" unfolding insert_index_def by auto

(* For i < Suc n, insert_index i maps {0..<n} onto {0..<Suc n} without i.
   Both inclusions are shown elementwise; for the second one the preimage of
   i' is i' itself (when i' < i) or i' - 1 (when i' > i). *)
lemma insert_index_image:
  assumes i: "i < Suc n"
  shows "insert_index i ` {0..<n} = {0..<Suc n} - {i}" (is "?L = ?R")
proof(unfold set_eq_iff, intro allI iffI)
  let ?N = "{0..<Suc n}"
  fix i'
  { assume L: "i'  ?L"
    then obtain i''
      where ins: "i' = insert_index i i''" and i'': "i''  {0..<n}" by auto
    show "i'  ?N - {i}"
    proof(rule DiffI)
      show "i'  ?N" using ins unfolding insert_index_def using i'' by auto
      show "i'  {i}"
        unfolding singleton_iff
        unfolding ins unfolding insert_index_def by auto
    qed
  }
  { assume R: "i'  ?R"
    show "i'  ?L"
    proof(cases rule: linorder_cases[of "i'" "i"])
      case less
        hence i': "i'  {0..<n}" using i by auto
        hence "insert_index i i' = i'" unfolding insert_index_def using less by auto
        thus ?thesis using i' by force
      next case equal
        (* i' = i contradicts membership in the right-hand side *)
        hence False using R by auto
        thus ?thesis by auto
      next case greater
        hence i'': "i'-1  {0..<n}" using i R by auto
        hence "insert_index i (i'-1) = i'"
          unfolding insert_index_def using greater by auto
        thus ?thesis using i'' by force
    qed
  }
qed

(* If f maps {0..<n} onto itself and j < Suc n, then insert_ran f j maps
   {0..<n} onto {0..<Suc n} without j. The image is rewritten step by step
   into insert_index j ` {0..<n}, to which insert_index_image applies. *)
lemma insert_ran_image:
  assumes j: "j < Suc n"
  assumes img: "f ` {0..<n} = {0..<n}"
  shows "insert_ran f j ` {0..<n} = {0..<Suc n} - {j}" (is "?L = ?R")
proof -
  have "?L = (λi. insert_index j (f i)) ` {0..<n}" unfolding insert_ran_def..
  also have "... = (insert_index j  f) ` {0..<n}" by auto
  also have "... = insert_index j ` f ` {0..<n}" by auto
  also have "... = insert_index j ` {0..<n}" unfolding img by auto
  finally show ?thesis using insert_index_image[OF j] by auto
qed

(* If f maps {0..<n} onto {0..<Suc n} without j, and i, j < Suc n, then
   insert_dom f i j maps {0..<Suc n} onto itself. Both inclusions are shown
   elementwise, following the three branches of insert_dom. *)
lemma insert_dom_image:
  assumes i: "i < Suc n" and j: "j < Suc n"
    and img: "f ` {0..<n} = {0..<Suc n} - {j}" (is "_ = ?N - _")
  shows "insert_dom f i j ` ?N = ?N" (is "?f ` _ = _")
proof(unfold set_eq_iff,intro allI iffI)
  fix j'
  (* First inclusion: every value of insert_dom f i j on ?N lies in ?N *)
  { assume "j'  ?f ` ?N"
    then obtain i' where i': "i'  ?N" and j': "j' = ?f i'" by auto
    show "j'  ?N"
    proof (cases rule:linorder_cases[of "i'" "i"])
      case less
        hence "i'  {0..<n}" using i by auto
        hence "f i' < Suc n" using imageI[of i' "{0..<n}" f] img by auto
        thus ?thesis
          unfolding j' unfolding insert_dom_def using less by auto
      next case equal
        thus ?thesis unfolding j' insert_dom_def using j by auto
      next case greater
        hence "i'-1  {0..<n}" using i' by auto
        hence "f (i'-1) < Suc n" using imageI[of "i'-1" "{0..<n}" f] img by auto
        thus ?thesis
          unfolding j' insert_dom_def using greater by auto
    qed
  }
  (* Second inclusion: j is reached at i, any other value via a preimage under f *)
  { assume j': "j'  ?N" show "j'  ?f ` ?N"
    proof (cases "j' = j")
      case True
        hence "?f i = j'" unfolding insert_dom_def by auto
        thus ?thesis using i by auto
      next case False
        hence j': "j'  ?N - {j}" using j j' by auto
        then obtain i' where j'fi: "j' = f i'" and i': "i'  {0..<n}"
          unfolding img[symmetric] by auto
        show ?thesis
        proof(cases "i' < i")
          case True thus ?thesis unfolding j'fi insert_dom_def using i' by auto
          next case False
            (* arguments not below i are shifted up by one *)
            hence "?f (Suc i') = j'" unfolding j'fi insert_dom_def using i' by auto
            thus ?thesis using i' by auto
        qed
    qed
  }
qed

(* If f is injective on {0..<n}, so is insert_ran f j. From equal values of
   insert_ran f j the proof first recovers f i = f i' (by comparing f i and
   f i' with j) and then applies injectivity of f. *)
lemma insert_ran_inj_on:
  assumes inj: "inj_on f {0..<n}" and j: "j < Suc n"
  shows "inj_on (insert_ran f j) {0..<n}" (is "inj_on ?f _")
proof (rule inj_onI)
  fix i i'
  assume i: "i  {0..<n}" and i': "i'  {0..<n}" and eq: "?f i = ?f i'"
  note eq2 = eq[unfolded insert_ran_def insert_index_def]
  have "f i = f i'"
  proof (cases "f i < j")
    case True
      moreover have "f i' < j" apply (rule ccontr) using eq2 True by auto
      ultimately show ?thesis using eq2 by auto
    next case False
      moreover have "¬ f i' < j" apply (rule ccontr) using eq2 False by auto
      ultimately show ?thesis using eq2 by auto
  qed
  from inj_onD[OF inj this i i'] show "i = i'".
qed

(* Under the hypotheses of insert_dom_image, insert_dom f i j is injective on
   {0..<Suc n}. The proof is a cardinality argument (eq_card_imp_inj_on) that
   uses only the image equation insert_dom_image[OF i j img]. *)
lemma insert_dom_inj_on:
  assumes inj: "inj_on f {0..<n}"
      and i: "i < Suc n" and j: "j < Suc n"
      and img: "f ` {0..<n} = {0..<Suc n} - {j}" (is "_ = ?N - _")
  shows "inj_on (insert_dom f i j) ?N"
  apply(rule eq_card_imp_inj_on)
  unfolding insert_dom_image[OF i j img] by simp+

(* If q permutes {0..<n} and i, j < Suc n, then permutation_insert i j q is a
   bijection on {0..<Suc n}. Surjectivity and injectivity are obtained by
   chaining the corresponding lemmas for insert_ran and insert_dom. *)
lemma permutation_insert_bij_betw:
  assumes q: "q permutes {0..<n}" and i: "i < Suc n" and j: "j < Suc n"
  shows "bij_betw (permutation_insert i j q) {0..<Suc n} {0..<Suc n}"
    (is "bij_betw ?q ?N _")
proof (rule bij_betw_imageI)
  have img: "q ` {0..<n} = {0..<n}" using permutes_image[OF q].
  show "?q ` ?N = ?N"
    unfolding permutation_insert_def
    using insert_dom_image[OF i j insert_ran_image[OF j permutes_image[OF q]]].
  have inj: "inj_on q {0..<n}"
    apply(rule subset_inj_on) using permutes_inj[OF q] by auto
  show "inj_on ?q ?N"
    unfolding permutation_insert_def
    using insert_dom_inj_on[OF _ i j]
    using insert_ran_inj_on[OF inj j] insert_ran_image[OF j img] by auto
qed

(* If q permutes {0..<n} and i, j < Suc n, then permutation_insert i j q
   permutes {0..<Suc n}. Bijectivity is permutation_insert_bij_betw; indices
   outside {0..<Suc n} are fixed because q fixes i' - 1 (permutes_others). *)
lemma permutation_insert_permutes:
  assumes q: "q permutes {0..<n}"
      and i: "i < Suc n" and j: "j < Suc n"
  shows "permutation_insert i j q permutes {0..<Suc n}" (is "?p permutes ?N")
  using permutation_insert_bij_betw[OF q i j]
proof (rule bij_imp_permutes)
  fix i' assume "i'  ?N"
  moreover hence "q (i'-1) = i'-1" using permutes_others[OF q] by auto
  ultimately show "?p i' = i'"
    unfolding permutation_insert_expand using i j by auto
qed

(* The permutations of {0..<Suc n} that send i to j are exactly the image of
   permutation_insert i j on the permutations of {0..<n}.
   One direction uses permutation_delete as the preimage
   (permutation_insert_delete), the other permutation_insert_permutes. *)
lemma permutation_fix:
  assumes i: "i < Suc n" and j: "j < Suc n"
  shows "{ p. p permutes {0..<Suc n}  p i = j } =
         permutation_insert i j ` { q. q permutes {0..<n} }"
    (is "?L = ?R")
  unfolding set_eq_iff
proof(intro allI iffI)
  let ?N = "{0..<Suc n}"
  fix p
  { assume "p  ?L"
    hence p: "p permutes ?N" and pij: "p i = j" by auto
    show "p  ?R"
      unfolding mem_Collect_eq
      using permutation_delete_permutes[OF p i]
      using permutation_insert_delete[OF p i,symmetric]
      unfolding pij by auto
  }
  { assume "p  ?R"
    then obtain q where pq: "p = permutation_insert i j q" and q: "q permutes {0..<n}" by auto
    hence "p i = j" unfolding permutation_insert_expand by simp
    thus "p  ?L"
      using pq permutation_insert_permutes[OF q i j] by auto
  }
qed

(* For a fixed j in S, the permutations of S are partitioned according to the
   argument i in S that is mapped to j: the set of all permutations of S is
   the union over i of those with p i = j. Existence of such an i comes from
   permutes_image. *)
lemma permutation_split_ran:
  assumes j: "j  S"
  shows "{ p. p permutes S } = (i  S. { p. p permutes S  p i = j })"
  (is "?L = ?R")
  unfolding set_eq_iff
proof(intro allI iffI)
  fix p
  { assume "p  ?L"
    hence p: "p permutes S" by auto
    obtain i where i: "i  S" and pij: "p i = j" using j permutes_image[OF p] by force
    thus "p  ?R" using p by auto
  }
  { assume "p  ?R"
    then obtain i
      where p: "p permutes S" and i: "i  S" and pij: "p i = j"
      by auto
    show "p  ?L"
      unfolding mem_Collect_eq using p.
  }
qed

(* For distinct arguments i and i', no permutation of S maps both of them to
   the same value j, so the two sets of permutations are disjoint.
   Follows from injectivity of permutations (permutes_inj). *)
lemma permutation_disjoint_dom:
  assumes i: "i  S" and i': "i'  S" and j: "j  S" and ii': "i  i'"
  shows "{ p. p permutes S  p i = j }  { p. p permutes S  p i' = j } = {}"
    (is "?L  ?R = {}")
proof -
  {
    fix p assume "p  ?L  ?R"
    hence p: "p permutes S" and "p i = j" and "p i' = j" by auto
    hence "p i = p i'" by auto
    note injD[OF permutes_inj[OF p] this]
    hence False using ii' by auto
  }
  thus ?thesis by auto
qed

(* For distinct values j and j', no permutation maps the same argument i to
   both, so the two sets of permutations are disjoint. *)
lemma permutation_disjoint_ran:
  assumes i: "i  S" and j: "j  S" and j': "j'  S" and jj': "j  j'"
  shows "{ p. p permutes S  p i = j }  { p. p permutes S  p i = j' } = {}"
    (is "?L  ?R = {}")
proof -
  {
    fix p assume "p  ?L  ?R"
    hence "p permutes S" and "p i = j" and "p i = j'" by auto
    hence False using jj' by auto
  }
  thus ?thesis by auto
qed

(* permutation_insert i j is injective on the permutations of {0..<n}.
   Given permutation_insert i j q = permutation_insert i j q', the unfolded
   equation is instantiated pointwise (fact eq): for x below i at x itself,
   otherwise at Suc x; comparing q x and q' x with j then yields q x = q' x. *)
lemma permutation_insert_inj_on:
  assumes "i < Suc n"
  assumes "j < Suc n"
  shows "inj_on (permutation_insert i j) { q. q permutes {0..<n} }"
  (is "inj_on ?f ?S")
proof (rule inj_onI)
  fix q q'
  assume "q  ?S" "q'  ?S"
  hence q: "q permutes {0..<n}" and q': "q' permutes {0..<n}" by auto
  assume "?f q = ?f q'"
  hence eq: "permutation_insert i j q = permutation_insert i j q'" by auto
  (* pointwise form of the equation with permutation_insert unfolded *)
  note eq = cong[OF eq refl, unfolded permutation_insert_expand]
  show "q = q'"
  proof(rule ext)
    fix x
    show "q x = q' x"
    proof(cases "x < i")
      case True thus ?thesis apply(cases "q x < j";cases "q' x < j") using eq[of x] by auto
      next case False thus ?thesis
        apply(cases "q x < j";cases "q' x < j") using eq[of "Suc x"] by auto
    qed
  qed
qed

(* Sign of an inserted permutation: for p permuting {0..<n} and i, j < Suc n,
   signof (permutation_insert i j p) = (-1)^(i+j) * signof p.
   Proof outline:
   - col: with insertion position n, the inserted value is moved from n down to
     n - j by induction on j; each step is one adjacent transposition
     (permutation_insert_column_step) contributing a factor -1.
   - row_base: instance of col giving the claim for insertion position n.
   - row: the insertion position is moved from n down to n - i by induction on
     i, again one adjacent transposition per step (permutation_insert_row_step).
   The final statement is the instance of row at n - i. *)
lemma signof_permutation_insert:
  assumes p: "p permutes {0..<n}" and i: "i < Suc n" and j: "j < Suc n"
  shows "signof (permutation_insert i j p) = (-1::'a::ring_1)^(i+j) * signof p"
proof -
  { fix j assume "j  n"
    hence "signof (permutation_insert n (n-j) p) = (-1::'a)^(n+(n-j)) * signof p"
    proof(induct "j")
      case 0 show ?case using permutation_insert_base[OF p] by (simp add: mult_2[symmetric])
      next case (Suc j)
        hence Sjn: "Suc j  n" and j: "j < n" and Sj: "n - Suc j < n" by auto
        hence n0: "n > 0" by auto
        have ease: "Suc (n - Suc j) = n - j" using j by auto
        let ?swap = "Fun.swap (n - Suc j) (n - j) id"
        let ?prev = "permutation_insert n (n - j) p"
        (* one column step: the new permutation is ?swap combined with ?prev *)
        have "signof (permutation_insert n (n - Suc j) p) = signof (?swap  ?prev)"
          unfolding permutation_insert_column_step[OF p Sj, unfolded ease]..
        also have "... = signof ?swap * signof ?prev"
          proof(rule signof_compose)
            show "?swap permutes {0..<Suc n}" by (rule permutes_swap_id,auto)
            show "?prev permutes {0..<Suc n}" by (rule permutation_insert_permutes[OF p],auto)
          qed
        also have "signof ?swap = -1"
          proof-
            have "n - Suc j < n - j" using Sjn by simp
            thus ?thesis unfolding signof_def sign_swap_id by simp
          qed
        also have "signof ?prev = (-1::'a) ^ (n + (n - j)) * signof p" using Suc(1) j by auto
        (* remaining steps only rearrange the exponent and cancel (-1)^2 *)
        also have "(-1) * ... =  (-1) ^ (1 + n + (n - j)) * signof p" by simp
        also have "n - j = 1 + (n - Suc j)" using j by simp
        also have "1 + n + ... = 2 + (n + (n - Suc j))" by simp
        also have "(-1::'a) ^ ... = (-1) ^ 2 * (-1) ^ (n + (n - Suc j))" by simp
        also have "... = (-1) ^ (n + (n - Suc j))" by simp
        finally show ?case.
    qed
  }
  note col = this
  have nj: "n - j  n" using j by auto
  have row_base: "signof (permutation_insert n j p) = (-1::'a)^(n+j) * signof p"
    using col[OF nj] using j by simp
  { fix i assume "i  n"
    hence "signof (permutation_insert (n-i) j p) = (-1::'a)^((n-i)+j) * signof p"
    proof (induct i)
      case 0 show ?case using row_base by auto
      next case (Suc i)
        hence Sin: "Suc i  n" and i: "i  n" and Si: "n - Suc i < n" by auto
        have ease: "Suc (n - Suc i) = n - i" using Sin by auto
        let ?prev = "permutation_insert (n-i) j p"
        let ?swap = "Fun.swap (n - Suc i) (n-i) id"
        (* one row step: the new permutation is ?prev combined with ?swap *)
        have "signof (permutation_insert (n - Suc i) j p) = signof (?prev  ?swap)"
          using permutation_insert_row_step[of "n - Suc i"] unfolding ease by auto
        also have "... = signof ?prev * signof ?swap"
          proof(rule signof_compose)
            show "?swap permutes {0..<Suc n}" by (rule permutes_swap_id,auto)
            show "?prev permutes {0..<Suc n}"
              apply(rule permutation_insert_permutes[OF p]) using j by auto
          qed
        also have "signof ?swap = (-1)"
          proof-
            have "n - Suc i < n - i" using Sin by simp
            thus ?thesis unfolding signof_def sign_swap_id by simp
          qed
        also have "signof ?prev = (-1::'a) ^ (n - i + j) * signof p"
          using Suc(1)[OF i].
        also have "... * (-1) = (-1) ^ Suc (n - i + j) * signof p"
          by auto
        also have "Suc (n - i + j) = Suc (Suc (n - Suc i + j))"
          using Sin by auto
        (* NOTE(review): this step is stated at type int while the rest of the
           calculation is at type 'a; confirm that this is intended. *)
        also have "(-1::int) ^ ... = (-1) ^ (n - Suc i + j)" by auto
        ultimately show ?case by auto
    qed
  }
  note row = this
  have ni: "n - i  n" using i by auto
  show ?thesis using row[OF ni] using i by simp
qed

(* Auxiliary set equality used in laplace_expansion_column: the graph of
   permutation_insert i j q on {0..<Suc n} without i coincides with the set of
   pairs (insert_index i i'', insert_index j (q i'')) for i'' < n.
   Witnesses: delete_index i i' in one direction, insert_index i i'' in the
   other. *)
lemma foo:
  assumes i: "i < Suc n" and j: "j < Suc n"
  assumes q: "q permutes {0..<n}"
  shows "{(i', permutation_insert i j q i') | i'. i'  {0..<Suc n} - {i} } =
  { (insert_index i i'', insert_index j (q i'')) | i''. i'' < n }" (is "?L = ?R")
  unfolding set_eq_iff
proof(intro allI iffI)
  fix ij
  { assume "ij  ?L"
    then obtain i'
      where ij: "ij = (i', permutation_insert i j q i')" and i': "i' < Suc n" and i'i: "i'  i"
      by auto
    show "ij  ?R" unfolding mem_Collect_eq
    proof(intro exI conjI)
      show "ij = (insert_index i (delete_index i i'), insert_index j (q (delete_index i i')))"
        using ij unfolding insert_delete_index[OF i'i] using i'i
        unfolding permutation_insert_expand insert_index_def delete_index_def by auto
      show "delete_index i i' < n" using i' i i'i unfolding delete_index_def by auto
    qed
  }
  { assume "ij  ?R"
    then obtain i''
      where ij: "ij = (insert_index i i'', insert_index j (q i''))" and i'': "i'' < n"
      by auto
    show "ij  ?L" unfolding mem_Collect_eq
    proof(intro exI conjI)
      show "insert_index i i''  {0..<Suc n} - {i}"
        unfolding insert_index_image[OF i,symmetric] using i'' by auto
      have "insert_index j (q i'') = permutation_insert i j q (insert_index i i'')"
        unfolding permutation_insert_expand insert_index_def by auto
      thus "ij = (insert_index i i'', permutation_insert i j q (insert_index i i''))"
        unfolding ij by auto
    qed
  }
qed

(* mat_delete A i j is the matrix with one row and one column fewer than A whose
   entry (i', j') is taken from A at the row/column indices that skip row i and
   column j (indices below the deleted one are kept, the others incremented). *)
definition "mat_delete A i j 
  mat (dim_row A - 1) (dim_col A - 1) (λ(i',j').
    A $$ (if i' < i then i' else Suc i', if j' < j then j' else Suc j'))"

(* Dimensions of mat_delete A i j: one row and one column fewer than A. *)
lemma mat_delete_dim[simp]:
  "dim_row (mat_delete A i j) = dim_row A - 1"
  "dim_col (mat_delete A i j) = dim_col A - 1"
  unfolding mat_delete_def by auto

(* Carrier version of mat_delete_dim: an m x n matrix yields an
   (m-1) x (n-1) matrix. *)
lemma mat_delete_carrier:
  assumes A: "A  carrier_mat m n"
  shows "mat_delete A i j  carrier_mat (m-1) (n-1)" unfolding mat_delete_def using A by auto

(* Entry (i', j') of mat_delete A i j equals the entry of A at the indices
   obtained by re-inserting the deleted row i and column j
   (insert_index i i', insert_index j j'), for a square matrix A of
   dimension Suc n and in-range indices. *)
lemma mat_delete_index:
  assumes A: "A  carrier_mat (Suc n) (Suc n)"
      and i: "i < Suc n" and j: "j < Suc n"
      and i': "i' < n" and j': "j' < n"
  shows "A $$ (insert_index i i', insert_index j j') = mat_delete A i j $$ (i', j')"
  unfolding mat_delete_def
  unfolding insert_index_def
  using A i j i' j' by auto

(* cofactor A i j is the determinant of A with row i and column j removed
   (mat_delete A i j), multiplied by the sign (-1)^(i+j). *)
definition "cofactor A i j = (-1)^(i+j) * det (mat_delete A i j)"


(* Laplace expansion along column j: for an n x n matrix A over a commutative
   ring and j < n, det A is the sum over rows i < n of
   A $$ (i,j) * cofactor A i j.
   Proof outline: write det A as a sum over permutations (det_def'), split the
   permutations by the row i that is sent to column j (permutation_split_ran),
   describe each class as the image of permutation_insert i j
   (permutation_fix), pull the factor A $$ (i,j) out of each product, reindex
   the remaining product over the smaller matrix (foo, mat_delete_index), and
   recognise det (mat_delete A i j). *)
lemma laplace_expansion_column:
  assumes A: "(A :: 'a :: comm_ring_1 mat)  carrier_mat n n"
      and j: "j < n"
  shows "det A = (i<n. A $$ (i,j) * cofactor A i j)"
proof -
  (* work with l = n - 1 so that the dimension has the form Suc l *)
  define l where "l = n-1"
  have A: "A  carrier_mat (Suc l) (Suc l)"
   and jl: "j < Suc l" using A j unfolding l_def by auto
  let ?N = "{0 ..< Suc l}"
  (* f p i: entry selected by p in row i; g p: product over all rows;
     h p: signed product; Q: permutations of the smaller index set *)
  define f where "f = (λp i. A $$ (i, p i))"
  define g where "g = (λp. prod (f p) ?N)"
  define h where "h = (λp. signof p * g p)"
  define Q where "Q = { q . q permutes {0..<l} }"
  have jN: "j  ?N" using jl by auto
  have disj: "i  ?N. i'  ?N. i  i' 
    {p. p permutes ?N  p i = j}  {p. p permutes ?N  p i' = j} = {}"
    using permutation_disjoint_dom[OF _ _ jN] by auto
  have fin: "i?N. finite {p. p permutes ?N  p i = j}"
    using finite_permutations[of ?N] by auto

  have "det A = sum h { p. p permutes ?N }"
    using det_def'[OF A] unfolding h_def g_def f_def using atLeast0LessThan by auto
  (* split the permutations by the row that is mapped to column j *)
  also have "... = sum h (i?N. {p. p permutes ?N  p i = j})"
    unfolding permutation_split_ran[OF jN]..
  also have "... = (i?N. sum h { p | p. p permutes ?N  p i = j})"
    using sum.UNION_disjoint[OF _ fin disj] by auto
  also {
    fix i assume "i  ?N"
    hence i: "i < Suc l" by auto
    (* the class for row i is the image of permutation_insert i j on Q *)
    have "sum h { p | p. p permutes ?N  p i = j} = sum h (permutation_insert i j ` Q)"
      using permutation_fix[OF i jl] unfolding Q_def by auto
    also have "... = sum (h  permutation_insert i j) Q"
      unfolding Q_def using sum.reindex[OF permutation_insert_inj_on[OF i jl]].
    also have "... = ( q  Q.
      signof (permutation_insert i j q) * prod (f (permutation_insert i j q)) ?N)"
      unfolding h_def g_def Q_def by simp
    also {
      fix q assume "q  Q"
      hence q: "q permutes {0..<l}" unfolding Q_def by auto
      let ?p = "permutation_insert i j q"
      have fin: "finite (?N - {i})" by auto
      have notin: "i  ?N - {i}" by auto
      have close: "insert i (?N - {i}) = ?N" using notin i by auto
      (* pull the factor for row i out of the product *)
      have "prod (f ?p) ?N = f ?p i * prod (f ?p) (?N-{i})"
        unfolding prod.insert[OF fin notin, unfolded close] by auto
      also have "... = A $$ (i, j) * prod (f ?p) (?N-{i})"
        unfolding f_def Q_def using permutation_insert_inserted by simp
      (* rewrite the remaining product as a product over a set of index pairs *)
      also have "prod (f ?p) (?N-{i}) = prod (λi'. A $$ (i', permutation_insert i j q i')) (?N-{i})"
        unfolding f_def..
      also have "... = prod (λij. A $$ ij) ((λi'. (i', permutation_insert i j q i')) ` (?N-{i}))"
        (is "_ = prod _ ?part")
        unfolding prod.reindex[OF inj_on_convol_ident] o_def..
      also have "?part = {(i', permutation_insert i j q i') | i'. i'  ?N-{i} }"
        unfolding image_def by metis
      also have "... = {(insert_index i i'', insert_index j (q i'')) | i''. i'' < l}"
        unfolding foo[OF i jl q]..
      also have "... = ((λi''. (insert_index i i'', insert_index j (q i''))) ` {0..<l})"
        unfolding image_def by auto
      (* reindex back to a product over {0..<l}; needs injectivity of the pairing *)
      also have "prod (λij. A $$ ij)... = prod ((λij. A $$ ij)  (λi''. (insert_index i i'', insert_index j (q i'')))) {0..<l}"
        proof(subst prod.reindex[symmetric])
          have 1: "inj (λi''. (i'', insert_index j (q i'')))" using inj_on_convol_ident.
          have 2: "inj (λ(i'',j). (insert_index i i'', j))"
            apply (intro injI) using injD[OF insert_index_inj_on[of _ UNIV]] by auto
          have "inj (λi''. (insert_index i i'', insert_index j (q i'')))"
            using inj_compose[OF 2 1] unfolding o_def by auto
          thus "inj_on (λi''. (insert_index i i'', insert_index j (q i''))) {0..<l}"
            using subset_inj_on by auto
        qed auto
      also have "... = prod (λi''. A $$ (insert_index i i'', insert_index j (q i''))) {0..<l}"
        by auto
      (* entries of A at inserted indices are the entries of mat_delete A i j *)
      also have "... = prod (λi''. mat_delete A i j $$ (i'', q i'')) {0..<l}"
      proof (rule prod.cong[OF refl], unfold atLeastLessThan_iff, elim conjE)
        fix x assume x: "x < l"
        show "A $$ (insert_index i x, insert_index j (q x)) = mat_delete A i j $$ (x, q x)"
          apply(rule mat_delete_index[OF A i jl]) using q x by auto
      qed
      finally have "prod (f ?p) ?N =
        A $$ (i, j) * (i'' = 0..< l. mat_delete A i j $$ (i'', q i''))"
        by auto
      hence "signof ?p * prod (f ?p) ?N  = (-1::'a)^(i+j) * signof q * ..."
        unfolding signof_permutation_insert[OF q i jl] by auto
    }
    hence "... = ( q  Q. (-1)^(i+j) * signof q *
      A $$ (i, j) * (i'' = 0 ..< l. mat_delete A i j $$ (i'', q i'')))"
      by(intro sum.cong[OF refl],auto)
    (* move the factors that do not depend on q in front of the sum *)
    also have "... = (  q  Q. A $$ (i, j) * (-1)^(i+j) *
       ( signof q * (i'' = 0..< l. mat_delete A i j $$ (i'', q i'')) ) )"
      by (intro sum.cong[OF refl],auto)
    also have "... = A $$ (i, j) * (-1)^(i+j) *
      (  q  Q. signof q * (i''= 0 ..< l. mat_delete A i j $$ (i'', q i'')) )"
      unfolding sum_distrib_left by auto
    (* the remaining sum is the determinant of the smaller matrix *)
    also have "... = (A $$ (i, j) * (-1)^(i+j) * det (mat_delete A i j))"
      unfolding det_def'[OF mat_delete_carrier[OF A]]
      unfolding Q_def by auto
    finally have "sum h {p | p. p permutes ?N  p i = j} = A $$ (i, j) * cofactor A i j"
      unfolding cofactor_def by auto
  }
  hence "... = (i?N. A $$ (i,j) * cofactor A i j)" by auto
  finally show ?thesis unfolding atLeast0LessThan using A j unfolding l_def by auto
qed

(* Laplace expansion along row i, obtained from laplace_expansion_column
   applied to the transpose of A: det is invariant under transposition
   (det_transpose), and deleting row j / column i of the transpose and
   transposing back gives mat_delete A i j. *)
lemma laplace_expansion_row:
  assumes A: "(A :: 'a :: comm_ring_1 mat)  carrier_mat n n"
      and i: "i < n"
    shows "det A = (j<n. A $$ (i,j) * cofactor A i j)"
proof -
  have "det A = det (AT)" using det_transpose[OF A] by simp
  also have " = (j<n. AT $$ (j, i) * cofactor AT j i)" 
    by (rule laplace_expansion_column[OF _ i], insert A, auto)
  also have " = (j<n. A $$ (i,j) * cofactor A i j)" unfolding cofactor_def
  proof (rule sum.cong[OF refl], rule arg_cong2[of _ _ _ _ "λ x y. x * y"], goal_cases)
    (* goal 1: the matrix entries agree *)
    case (1 j)
    thus ?case using A i by auto
  next
    (* goal 2: the signed minors agree *)
    case (2 j)
    have "det (mat_delete AT j i) = det ((mat_delete AT j i)T)" 
      by (subst det_transpose, insert A, auto simp: mat_delete_def)
    also have "(mat_delete AT j i)T = mat_delete A i j" 
      unfolding mat_delete_def using A by auto
    finally show ?case by (simp add: ac_simps)
  qed
  finally show ?thesis .
qed


(* Degree bound for determinants of polynomial matrices: if every entry of the
   n x n matrix A has degree bounded by k, then det A has degree bounded by
   k * n. For each permutation the degrees along the selected entries sum to
   at most k * n (fact * ); the bound then carries over to the sum over
   permutations via degree_sum_le and degree_prod_sum_le. *)
lemma degree_det_le: assumes " i j. i < n  j < n  degree (A $$ (i,j))  k"
  and A: "A  carrier_mat n n" 
shows "degree (det A)  k * n" 
proof -
  {
    fix p
    assume p: "p permutes {0..<n}"
    have "(x = 0..<n. degree (A $$ (x, p x)))  (x = 0..<n. k)"     
      by (rule sum_mono[OF assms(1)], insert p, auto)
    also have " = k * n" unfolding sum_constant by simp
    also note calculation 
  } note * = this
  show ?thesis unfolding det_def'[OF A]
    by (rule degree_sum_le, insert *, auto simp: finite_permutations signof_def 
      intro!: order.trans[OF degree_prod_sum_le])
qed

(* For an upper triangular square matrix over an integral domain, the
   determinant is 0 exactly when 0 occurs on the diagonal.
   Direct consequence of det_upper_triangular. *)
lemma upper_triangular_imp_det_eq_0_iff:
  fixes A :: "'a :: idom mat"
  assumes "A  carrier_mat n n" and "upper_triangular A"
  shows "det A = 0  0  set (diag_mat A)"
  using assms by (auto simp: det_upper_triangular)

(* A square matrix with two equal columns at distinct positions i and j has
   determinant 0. Reduced to det_identical_rows by passing to the transpose
   (det_transpose). *)
lemma det_identical_columns:
  assumes A: "A  carrier_mat n n"  
    and ij: "i  j"
    and i: "i < n" and j: "j < n"
    and r: "col A i = col A j"
  shows "det A = 0"
proof-
  have "det A = det AT" using det_transpose[OF A] ..
  also have "... = 0" 
  proof (rule det_identical_rows[of _ n i j])
     show "row (transpose_mat A) i = row (transpose_mat A) j"
       using A i j r by auto
  qed (auto simp add: assms)
  finally show ?thesis .
qed

(* adj_mat A has the same dimensions as A; its entry (i,j) is cofactor A j i,
   i.e. the cofactor with row and column index swapped. *)
definition adj_mat :: "'a :: comm_ring_1 mat  'a mat" where
  "adj_mat A = mat (dim_row A) (dim_col A) (λ (i,j). cofactor A j i)" 

(* Basic facts about adj_mat for an n-by-n matrix A:
     1. adj_mat A is again n-by-n;
     2. A * adj_mat A and
     3. adj_mat A * A
   both equal the right-hand side of the statement, whose (i,j) entry is
   shown below to be "if i = j then det A else 0".
   Both products are proved entrywise (eq_matI).  For fixed i, j < n the entry
   of the product is rewritten as a sum of entries times cofactors, which is
   recognised as the Laplace expansion of an auxiliary matrix B:
     - first part:  B is A with row j overwritten by row i of A
                    (expansion along row j);
     - second part: B is A with column i overwritten by column j of A
                    (expansion along column i).
   If i = j then B = A; otherwise B has two identical rows (resp. columns), so
   det B = 0 by det_identical_rows (resp. det_identical_columns). *)
lemma adj_mat: assumes A: "A  carrier_mat n n"
  shows "adj_mat A  carrier_mat n n"
  "A * adj_mat A = det A m 1m n" 
  "adj_mat A * A = det A m 1m n" 
proof -
  from A have dims: "dim_row A = n" "dim_col A = n" by auto
  show aA: "adj_mat A  carrier_mat n n" unfolding adj_mat_def dims by simp  
  {
    fix i j
    assume ij: "i < n" "j < n" 
    (* B: copy of A whose row j is replaced by row i of A *)
    define B where "B = mat n n (λ (i',j'). if i' = j then A $$ (i,j') else A $$ (i',j'))" 
    have "(A * adj_mat A) $$ (i,j) = ( k < n. A $$ (i,k) * cofactor A j k)" 
      unfolding times_mat_def scalar_prod_def adj_mat_def using ij A by (auto intro: sum.cong)
    also have " = ( k < n. A $$ (i,k) * (-1)^(j + k) * det (mat_delete A j k))" 
      unfolding cofactor_def by (auto intro: sum.cong)
    (* deleting row j makes A and B indistinguishable, and B's row j is A's row i *)
    also have " = ( k < n. B $$ (j,k) * (-1)^(j + k) * det (mat_delete B j k))" 
      by (rule sum.cong[OF refl], intro arg_cong2[of _ _ _ _ "λ x y. y * _ * det x"], insert A ij,
        auto simp: B_def mat_delete_def)
    also have " = ( k < n. B $$ (j,k) * cofactor B j k)" 
      unfolding cofactor_def by (simp add: ac_simps)
    also have " = det B" 
      by (rule laplace_expansion_row[symmetric], insert ij, auto simp: B_def)
    also have " = (if i = j then det A else 0)" 
    proof (cases "i = j")
      case True
      hence "B = A" using A by (auto simp add: B_def)
      with True show ?thesis by simp
    next
      case False
      (* rows i and j of B are both row i of A *)
      have "det B = 0" 
        by (rule Determinant.det_identical_rows[OF _ False ij], insert A ij, auto simp: B_def)
      with False show ?thesis by simp
    qed
    also have " = (det A m 1m n) $$ (i,j)"  using ij by auto
    finally have "(A * adj_mat A) $$ (i, j) = (det A m 1m n) $$ (i, j)" .
  } note main = this
  show "A * adj_mat A = det A m 1m n"
    by (rule eq_matI[OF main], insert A aA, auto)
  (* now the completely symmetric version *)
  {
    fix i j
    assume ij: "i < n" "j < n" 
    (* B: copy of A whose column i is replaced by column j of A *)
    define B where "B = mat n n (λ (i',j'). if j' = i then A $$ (i',j) else A $$ (i',j'))" 
    have "(adj_mat A * A) $$ (i,j) = ( k < n. A $$ (k,j) * cofactor A k i)" 
      unfolding times_mat_def scalar_prod_def adj_mat_def using ij A by (auto intro: sum.cong)
    also have " = ( k < n. A $$ (k,j) * (-1)^(k + i) * det (mat_delete A k i))" 
      unfolding cofactor_def by (auto intro: sum.cong)
    also have " = ( k < n. B $$ (k,i) * (-1)^(k + i) * det (mat_delete B k i))" 
      by (rule sum.cong[OF refl], intro arg_cong2[of _ _ _ _ "λ x y. y * _ * det x"], insert A ij,
        auto simp: B_def mat_delete_def)
    also have " = ( k < n. B $$ (k,i) * cofactor B k i)" 
      unfolding cofactor_def by (simp add: ac_simps)
    also have " = det B" 
      by (rule laplace_expansion_column[symmetric], insert ij, auto simp: B_def)
    also have " = (if i = j then det A else 0)" 
    proof (cases "i = j")
      case True
      hence "B = A" using A by (auto simp add: B_def)
      with True show ?thesis by simp
    next
      case False
      (* columns i and j of B are both column j of A *)
      have "det B = 0" 
        by (rule Determinant.det_identical_columns[OF _ False ij], insert A ij, auto simp: B_def)
      with False show ?thesis by simp
    qed
    also have " = (det A m 1m n) $$ (i,j)"  using ij by auto
    finally have "(adj_mat A * A) $$ (i, j) = (det A m 1m n) $$ (i, j)" .
  } note main = this
  show "adj_mat A * A = det A m 1m n"
    by (rule eq_matI[OF main], insert A aA, auto)
qed

(* replace_col A b k: matrix with the dimensions of A whose column k holds the
   entries b $ i of the vector b; every other entry is copied from A. *)
definition "replace_col A b k = mat (dim_row A) (dim_col A) (λ (i,j). if j = k then b $ i else A $$ (i,j))"

(* Cramer-style identity: for an n-by-n matrix A, a vector x of length n and
   k < n, replacing column k of A by the product A *v x gives a matrix whose
   determinant is x $ k * det A.
   Proof outline (with b = A *v x): start from x $ k * det A, rewrite it using
   lemma adj_mat (adj_mat A * A) into component k of adj_mat A *v b, i.e. the
   scalar product of row k of adj_mat A with b, and identify that sum with
   the Laplace expansion of replace_col A b k along column k. *)
lemma cramer_lemma_mat:  
  assumes A: "A  carrier_mat n n" 
  and x: "x  carrier_vec n" 
  and k: "k < n" 
shows "det (replace_col A (A *v x) k) = x $ k * det A" 
proof -
  define b where "b = A *v x" 
  have b: "b  carrier_vec n" using A x unfolding b_def by auto
  let ?Ab = "replace_col A b k" 
  have Ab: "?Ab  carrier_mat n n" using A by (auto simp: replace_col_def)
  have "x $ k * det A = (det A v x) $ k" using A k x by auto
  also have "det A v x = det A v (1m n *v x)" using x by auto
  also have " = (det A m 1m n) *v x" using A x by auto
  also have " = (adj_mat A * A) *v x" using adj_mat[OF A] by simp
  also have " = adj_mat A *v b" using adj_mat[OF A] A x unfolding b_def
    by (metis assoc_mult_mat_vec)
  also have " $ k = row (adj_mat A) k  b" using adj_mat[OF A] b k by auto
  (* Laplace expansion of the replaced matrix along its column k *)
  also have " = det (replace_col A b k)" unfolding scalar_prod_def using b k A
    by (subst laplace_expansion_column[OF Ab k], auto intro!: sum.cong arg_cong[of _ _ det] 
      arg_cong[of _ _ "λ x. _ * x"] eq_matI
      simp: replace_col_def adj_mat_def Matrix.row_def cofactor_def mat_delete_def ac_simps)
  finally show ?thesis unfolding b_def by simp
qed


end

Theory Determinant_Impl

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Code Equations for Determinants›

text ‹We compute determinants on arbitrary rings by applying elementary row-operations
  to bring a matrix into upper-triangular form. Then the determinant can be determined
  by multiplying all entries on the diagonal. Moreover, the final result has to be divided
  by a factor which is determined by the row-operations that we performed. To this end,
  we require a division operation on the element type.

  The algorithm is parametric in a selection function for the pivot-element, e.g., for 
  matrices over polynomials it turned out that selecting a polynomial of minimal degree
  is beneficial.›

theory Determinant_Impl
imports
  Polynomial_Interpolation.Missing_Polynomial
  "HOL-Computational_Algebra.Polynomial_Factorial"
  Determinant
begin

(* A selection function receives a list of (index, entry) pairs and returns an
   index. *)
type_synonym 'a det_selection_fun = "(nat × 'a)list  nat"

(* Well-formedness of a selection function: on a non-empty list xs the chosen
   index f xs is the first component of some element of xs. *)
definition det_selection_fun :: "'a det_selection_fun  bool" where 
  "det_selection_fun f = ( xs. xs  []  f xs  fst ` set xs)"


(* Destruction rule: the defining property of det_selection_fun, unfolded. *)
lemma det_selection_funD: "det_selection_fun f  xs  []  f xs  fst ` set xs"
  unfolding det_selection_fun_def by auto

(* Specification of the function mf used by the elimination step "mute":
   it constrains every result triple f x y = (x', y', g) through the two
   equations x = x' * g and y * x' = x * y', together with a side condition
   on y.
   NOTE(review): the logical connectives of this formula are not legible in
   this copy of the file -- confirm the exact shape against the original. *)
definition mute_fun :: "('a :: comm_ring_1  'a  'a × 'a × 'a)  bool" where
  "mute_fun f = ( x y x' y' g. f x y = (x',y',g)  y  0 
    x = x' * g  y * x' = x * y')"

context
  fixes sel_fun :: "'a ::idom_divide det_selection_fun"
begin

subsection ‹Properties of triangular matrices›

text ‹
  Each column of a triangular matrix should satisfy the following property.
›

(* triangular_column j A: every entry A $$ (i,j) of column j that lies strictly
   below the diagonal (j < i, i < dim_row A) is 0. *)
definition triangular_column::"nat  'a mat  bool"
  where "triangular_column j A  i. j < i  i < dim_row A  A $$ (i,j) = 0"

(* Destruction rule for triangular_column (declared [dest]): yields the zero
   entry for a concrete row index i below the diagonal. *)
lemma triangular_columnD [dest]:
  "triangular_column j A  j < i  i < dim_row A  A $$ (i,j) = 0"
  unfolding triangular_column_def by auto

(* Introduction rule for triangular_column (declared [intro]): it suffices to
   show that an arbitrary entry of column j below the diagonal is 0. *)
lemma triangular_columnI [intro]:
  "(i. j < i  i < dim_row A  A $$ (i,j) = 0)  triangular_column j A"
  unfolding triangular_column_def by auto

text ‹
  The following predicate states that the first $k$ columns satisfy
  triangularity.
›

(* triangular_to k A: each column j with j < k satisfies triangular_column. *)
definition triangular_to:: "nat  'a mat  bool"
  where "triangular_to k A == j. j<k  triangular_column j A"

(* upper_triangular A is the same as triangular_to with bound dim_row A;
   proved by unfolding all three definitions. *)
lemma triangular_to_triangular: "upper_triangular A = triangular_to (dim_row A) A"
  unfolding triangular_to_def triangular_column_def upper_triangular_def
  by auto

(* Destruction rule for triangular_to (declared [dest]): an entry (i,j) with
   j < k, j < i and i < dim_row A is 0. *)
lemma triangular_toD [dest]:
  "triangular_to k A  j < k  j < i  i < dim_row A  A $$ (i,j) = 0"
  unfolding triangular_to_def triangular_column_def by auto

(* Introduction rule for triangular_to (declared [intro]): show that every
   entry (i,j) with j < k, j < i and i < dim_row A is 0. *)
lemma triangular_toI [intro]:
  "(i j. j < k  j < i  i < dim_row A  A $$ (i,j) = 0)  triangular_to k A"
  unfolding triangular_to_def triangular_column_def by auto

(* Extending the triangular part by one column: if the first k columns are
   triangular and column k itself is triangular, then the first Suc k columns
   are.  The proof splits on whether the column index equals k. *)
lemma triangle_growth:
  assumes tri:"triangular_to k A"
  and col:"triangular_column k A"
  shows "triangular_to (Suc k) A"
  unfolding triangular_to_def
proof (intro allI impI)
  fix i assume iSk:"i < Suc k"
  show "triangular_column i A"
  proof (cases "i = k")
    case True
      then show ?thesis using col by auto next
    case False
      then have "i < k" using iSk by auto
      thus ?thesis using tri unfolding triangular_to_def by auto
  qed
qed

(* Triangularity of the first k columns carries over to any smaller bound k'. *)
lemma triangle_trans: "triangular_to k A  k > k'  triangular_to k' A"
  by (intro triangular_toI, elim triangular_toD, auto)


subsection ‹Algorithms for Triangulization›

context 
  fixes mf :: "'a  'a  'a × 'a × 'a"
begin

(* One elimination step on the pair (r, A), targeting entry (k,l):
     - let p = A $$ (k,l); if p = 0 the pair is returned unchanged;
     - otherwise (q', p', g) = mf A_ll p is computed, row k is scaled by q'
       (multrow k q' A), addrow (-p') k l is applied to the result, and the
       accumulated factor r is multiplied by q'.
   The component g of the triple is not used here. *)
private fun mute :: "'a  nat  nat  'a × 'a mat  'a × 'a mat" where
  "mute A_ll k l (r,A) = (let p = A $$ (k,l) in if p = 0 then (r,A) else 
    case mf A_ll p of (q',p',g)  
      (r * q', addrow (-p') k l (multrow k q' A)))" 

(* mute changes neither the number of rows nor the number of columns.
   Both conclusions are declared [simp]. *)
lemma mute_preserves_dimensions:
  assumes "mute q k l (r,A) = (r',A')"
  shows [simp]: "dim_row A' = dim_row A" and [simp]: "dim_col A' = dim_col A"
using assms by (auto simp: Let_def split: if_splits prod.splits)

text ‹
  Algorithm @{term "mute k l"} sets the element in the $k$-th row and $l$-th column to 0.
›

(* Correctness of a single step: if mf satisfies mute_fun and mute is called
   with the diagonal entry A $$ (l,l) as first argument, then entry (k,l) of
   the resulting matrix is 0 (for in-range indices k, l satisfying the index
   hypothesis relating k and l).
   The proof instantiates the mute_fun property for the concrete triple
   returned by mf and unfolds mat_addrow_def.
   Note: the abbreviations a and b are defined but not referred to afterwards. *)
lemma mute_makes_0 :
 assumes mute_fun: "mute_fun mf"
 assumes "mute (A $$ (l,l)) k l (r,A) = (r',A')"
 "l < dim_row A"
 "l < dim_col A"
 "k < dim_row A"
 "k  l"
 shows "A' $$ (k,l) = 0"
proof -
  define a where "a = A $$ (l, l)"
  define b where "b = A $$ (k, l)"
  let ?mf = "mf (A $$ (l, l)) (A $$ (k, l))"
  obtain q' p' g where id: "?mf = (q',p',g)" by (cases ?mf, auto)
  note mf = mute_fun[unfolded mute_fun_def, rule_format, OF id]
  from assms show ?thesis
  unfolding mat_addrow_def using mf id by (auto simp: ac_simps Let_def split: if_splits)
qed

text ‹It will not touch unexpected rows.›
(* Frame property of mute: for in-range indices, an entry (i,j) in a row i
   other than the processed row k keeps its value. *)
lemma mute_preserves:
  "mute q k l (r,A) = (r',A') 
   i < dim_row A 
   j < dim_col A 
   l < dim_row A 
   k < dim_row A 
   i  k 
   A' $$ (i,j) = A $$ (i,j)"
   by (auto simp: Let_def split: if_splits prod.splits)

text ‹It preserves $0$s in the touched row.›
(* If both A $$ (i,j) and A $$ (l,j) are 0, then entry (i,j) is still 0 after
   mute.  No restriction on i is needed: the case i = k is covered because the
   entry of row l in the same column is 0 as well. *)
lemma mute_preserves_0:
  "mute q k l (r,A) = (r',A') 
   i < dim_row A 
   j < dim_col A 
   l < dim_row A 
   k < dim_row A 
   A $$ (i,j) = 0 
   A $$ (l,j) = 0 
   A' $$ (i,j) = 0"
   by (auto simp: Let_def split: if_splits prod.splits)

text ‹Hence, it preserves the triangularity of a partially triangular matrix.›
(* A mute step on row k with l < k keeps the first l columns triangular.
   For an entry (i,j) with j < l and j < i, both A $$ (i,j) and A $$ (l,j) are
   0 by triangularity of A, so mute_preserves_0 applies. *)
lemma mute_preserves_triangle:
 assumes rA' : "mute q k l (r,A) = (r',A')"
 and triA: "triangular_to l A"
 and lk: "l < k"
 and kr: "k < dim_row A"
 and lr: "l < dim_row A"
 and lc: "l < dim_col A"
 shows "triangular_to l A'"
proof (rule triangular_toI)
  fix i j
  assume jl: "j < l" and ji: "j < i" and ir': "i < dim_row A'"
  then have A0: "A $$ (i,j) = 0" using triA rA' by auto
  moreover have "A $$ (l,j) = 0" using triA jl jl lr by auto
  moreover have jc:"j < dim_col A" using jl lc by auto
  moreover have ir: "i < dim_row A" using ir' rA' by auto
  ultimately show "A' $$ (i,j) = 0"
    using mute_preserves_0[OF rA'] lr kr by auto
qed


text ‹Recursive application of @{const mute}›

(* sub1 q k l rA applies mute with first argument q and column l to the rows
   l + 1, l + 2, ..., l + k, in that order; for k = 0 it returns rA. *)
private fun sub1 :: "'a  nat  nat  'a × 'a mat  'a × 'a mat"
where "sub1 q 0 l rA = rA"
  | "sub1 q (Suc k) l rA = mute q (l + Suc k) l (sub1 q k l rA)"

(* sub1 preserves both dimensions of the matrix (simp rules).
   Induction on k; the step case names the intermediate result of
   sub1 q k l and relies on the simp rules for mute. *)
lemma sub1_preserves_dimensions[simp]:
  "sub1 q k l (r,A) = (r',A')  dim_row A' = dim_row A"
  "sub1 q k l (r,A) = (r',A')  dim_col A' = dim_col A"
proof (induction k arbitrary: r' A')
  case (Suc k)
    moreover obtain r' A' where rA': "sub1 q k l (r, A) = (r', A')" by force
    moreover fix r'' A'' assume "sub1 q (Suc k) l (r, A) = (r'', A'')" 
    ultimately show "dim_row A'' = dim_row A" "dim_col A'' = dim_col A" by auto
qed auto

(* Carrier version of dimension preservation for sub1 (simp rule). *)
lemma sub1_closed [simp]:
  "sub1 q k l (r,A) = (r',A')  A  carrier_mat m n  A'  carrier_mat m n"
  unfolding carrier_mat_def by auto

(* sub1 leaves the diagonal entry (l,l) untouched, provided l < dim_col A and
   k + l < dim_row A.  Reason: each mute step processes a row l + Suc k, and
   mute_preserves says that the other rows are unchanged.
   The statement is re-shown in implication form so that the induction on k
   can generalise over r' and A'. *)
lemma sub1_preserves_diagnal:
  assumes "sub1 q k l (r,A) = (r',A')"
  and "l < dim_col A"
  and "k + l < dim_row A"
  shows "A' $$ (l,l) = A $$ (l,l)"
using assms
proof -
  show "k + l < dim_row A  sub1 q k l (r,A) = (r',A') 
    A' $$ (l,l) = A $$ (l,l)"
  proof (induction k arbitrary: r' A')
    case (Suc k)
      (* intermediate result after k steps *)
      obtain r'' A'' where rA''[simp]: "sub1 q k l (r,A) = (r'',A'')" by force
      have [simp]:"dim_row A'' = dim_row A" and [simp]:"dim_col A'' = dim_col A"
        using snd_conv sub1_preserves_dimensions[OF rA''] by auto
      have "A'' $$ (l,l) = A $$ (l,l)" using assms Suc by auto
      have rA': "mute q (l + Suc k) l (r'', A'') = (r',A')"
        using Suc by auto
      show ?case using subst mute_preserves[OF rA'] Suc assms by auto
  qed auto
qed

text ‹Triangularity is respected by @{const sub1}.›
(* If the first l columns of A are triangular, they remain so after sub1,
   provided all processed rows exist (l + k < dim_row A).
   Induction on k; the step case combines the induction hypothesis for the
   intermediate matrix with mute_preserves_triangle. *)
lemma sub1_preserves_triangle:
  assumes "sub1 q k l (r,A) = (r',A')"
  and tri: "triangular_to l A"
  and lr: "l < dim_row A"
  and lc: "l < dim_col A"
  and lkr: "l + k < dim_row A"
  shows "triangular_to l A'"
using assms
proof -
  show "sub1 q k l (r,A) = (r',A')  l + k < dim_row A 
    triangular_to l A'"
  proof (induction k arbitrary: r' A')
  case (Suc k)
    then have "sub1 q (Suc k) l (r,A) = (r',A')" by auto
    moreover obtain r'' A''
      where rA'': "sub1 q k l (r, A) = (r'',A'')" by force
    ultimately
      have rA': "mute q (Suc (l + k)) l (r'',A'') = (r',A')" by auto
    have "triangular_to l A''" using rA'' Suc by auto
    thus ?case
      using Suc assms mute_preserves_triangle[OF rA'] rA'' by auto
  qed (insert assms,auto)
qed

context
  assumes mf: "mute_fun mf"
begin
(* Main property of sub1 (under the context assumption "mute_fun mf"):
   when called with the diagonal entry A $$ (l,l), every entry (i,l) with
   l < i and i bounded by k + l is 0 in the result.
   Induction on k.  In the step case the row index i is either the row
   Suc (l + k) processed last -- then mute_makes_0 applies, using that the
   diagonal entry is unchanged (sub1_preserves_diagnal) -- or an earlier row,
   which is 0 by the induction hypothesis and untouched by the last step
   (mute_preserves). *)
lemma sub1_makes_0s:
  assumes "sub1 (A $$ (l,l)) k l (r,A) = (r',A')"
  and lr: "l < dim_row A"
  and lc: "l < dim_col A"
  and li: "l < i"
  and "i  k + l"
  and "k + l < dim_row A"
  shows "A' $$ (i,l) = 0"
using assms
proof -
  show "sub1 (A $$ (l,l)) k l (r,A) = (r',A')  i  k + l  k + l < dim_row A 
    A' $$ (i,l) = 0"
  using lr lc li
  proof (induction k arbitrary: r' A')
  case (Suc k)
    (* (r', A') is the intermediate result after k steps, (r'', A'') the final one *)
    obtain r' A' where rA': "sub1 (A $$ (l,l)) k l (r, A) = (r',A')" by force
    fix r'' A''
    from sub1_preserves_diagnal[OF rA'] have AA': "A $$ (l, l) = A' $$ (l, l)" using Suc(2-) by auto
    assume "sub1 (A $$ (l,l)) (Suc k) l (r, A) = (r'',A'')"
    then have rA'': "mute (A $$ (l,l)) (Suc (l + k)) l (r', A') = (r'', A'')"
      using rA' by simp
    have ir: "i < dim_row A" using Suc by auto
    have il: "i  l" using li by auto
    have lr': "l < dim_row A'" using lr rA' by auto
    have lc': "l < dim_col A'" using lc rA' by auto
    have Slkr': "Suc (l+k) < dim_row A'" using Suc rA' by auto
    show "A'' $$ (i,l) = 0"
    proof (cases "Suc(l + k) = i")
      case True {
        (* i is the row handled by the last mute step *)
        have l: "Suc (l + k)  l" by auto
        show ?thesis
          using mute_makes_0[OF mf rA''[unfolded AA'] lr' lc' Slkr' l] ir il rA'
          by (simp add:True)
      } next
      case False {
        (* i was handled earlier; the last step does not touch row i *)
        then have ikl: "i  k+l" using Suc by auto
        have ir': "i < dim_row A'" using ir rA' by auto
        have lc': "l < dim_col A'" using lc rA' by auto
        have IH: "A' $$ (i,l) = 0" using rA' Suc False by auto
        thus ?thesis using mute_preserves[OF rA'' ir' lc'] rA' False Suc
          by simp
      }
    qed
  qed auto
qed

(* Running sub1 for dim_row A - Suc l steps, i.e. over all rows below row l,
   makes column l triangular.  The zero entries come from sub1_makes_0s. *)
lemma sub1_triangulizes_column:
  assumes rA': "sub1 (A $$ (l,l)) (dim_row A - Suc l) l (r,A) = (r',A')"
  and tri:"triangular_to l A"
  and r: "dim_row A > 0"
  and lr: "l < dim_row A"
  and lc: "l < dim_col A"
  shows "triangular_column l A'"
proof (intro triangular_columnI)
  fix i
  assume li: "l < i"
  assume ir: "i < dim_row A'"
    also have "... = dim_row A" using sub1_preserves_dimensions[OF rA'] by auto
    also have "... = dim_row A - l + l" using lr li by auto
    finally have ir2: "i  dim_row A - l + l" by auto
  show "A' $$ (i,l) = 0"
    apply (subst sub1_makes_0s[OF rA' lr lc])
    using li ir assms
    by auto
qed

text ‹
  The algorithm @{const sub1} increases the number of columns that form a triangle.
›
(* Combines sub1_preserves_triangle (columns below l stay triangular) and
   sub1_triangulizes_column (column l becomes triangular) via triangle_growth:
   after the full sub1 pass the first Suc l columns are triangular. *)
lemma sub1_grows_triangle:
  assumes rA': "sub1 (A $$ (l,l)) (dim_row A - Suc l) l (r,A) = (r',A')"
  and r: "dim_row A > 0"
  and tri:"triangular_to l A"
  and lr: "l < dim_row A"
  and lc: "l < dim_col A"
  shows "triangular_to (Suc l) A'"
proof -
  have "triangular_to l A'"
    using sub1_preserves_triangle[OF rA'] assms by auto
  moreover have "triangular_column l A'"
    using sub1_triangulizes_column[OF rA'] assms by auto
  ultimately show ?thesis by (rule triangle_growth)
qed
end

subsection ‹Finding Non-Zero Elements›

(* Pivot search in column l: the candidates are the pairs (i, A $$ (i,l)) for
   the rows i in [Suc l ..< dim_row A], filtered by a test on the entry
   (per the enclosing subsection, the non-zero ones are kept).
   Returns None if no candidate remains, otherwise Some of the index that the
   context parameter sel_fun picks from the remaining candidates. *)
private definition find_non0 :: "nat  'a mat  nat option" where
  "find_non0 l A = (let is = [Suc l ..< dim_row A];
    Ais = filter (λ (i,Ail). Ail  0) (map (λ i. (i, A $$ (i,l))) is)
    in case Ais of []  None | _  Some (sel_fun Ais))"

(* Properties of a successful pivot search, for a well-formed selection
   function: three facts about the returned index m, namely a statement about
   the entry A $$ (m,l) (first conclusion), l < m, and m < dim_row A.
   They follow because sel_fun returns the index of one of the filtered
   candidates (det_selection_funD). *)
lemma find_non0: assumes sel_fun: "det_selection_fun sel_fun"
  and res: "find_non0 l A = Some m"
  shows "A $$ (m,l)  0" "l < m" "m < dim_row A"
proof -
  let ?xs = "filter (λ (i,Ail). Ail  0) (map (λ i. (i, A $$ (i,l))) [Suc l..<dim_row A])"
  from res[unfolded find_non0_def Let_def]
  have xs: "?xs  []" and m: "m = sel_fun ?xs"
    by (cases ?xs, auto)+
  from det_selection_funD[OF sel_fun xs, folded m] show "A $$ (m, l)  0" "l < m" "m < dim_row A" by auto
qed

text ‹
  If @{term "find_non0 l A"} fails,
  then the $l$-th column of $A$ is already triangular.
›

(* If the pivot search returns None, column l is already triangular: the
   filtered candidate list is empty, so every entry (i,l) with l < i and
   i < dim_row A must be 0. *)
lemma find_non0_all0:
  "find_non0 l A = None  triangular_column l A"
proof (intro triangular_columnI) 
  fix i
  let ?xs = "filter (λ (i,Ail). Ail  0) (map (λ i. (i, A $$ (i,l))) [Suc l..<dim_row A])"
  assume none: "find_non0 l A = None" and li: "l < i" "i < dim_row A"
  from none have xs: "?xs = []"
    unfolding find_non0_def Let_def by (cases ?xs, auto)
  from li have "(i, A $$ (i,l))  set (map (λ i. (i, A $$ (i,l))) [Suc l..<dim_row A])" by auto
  with xs show "A $$ (i,l) = 0"
    by (metis (mono_tags) xs case_prodI filter_empty_conv)
qed

subsection ‹Determinant Preserving Growth of Triangle›

text ‹
  The algorithm @{const sub1} does not preserve determinants when it hits
  a $0$-valued diagonal element. To avoid this case, we introduce the following
  operation:
›

(* Elimination of column l with a pivot search:
     - if find_non0 l A returns None, (r,A) is returned unchanged;
     - if it returns Some m, rows m and l are swapped, the factor r is negated,
       and sub1 is run on the swapped matrix A' with its entry A' $$ (l,l) as
       first argument, for d - Suc l steps. *)
private fun sub2 :: "nat  nat  'a × 'a mat  'a × 'a mat"
  where "sub2 d l (r,A) = (
    case find_non0 l A of None  (r,A)
    | Some m  let A' = swaprows m l A in sub1 (A' $$ (l,l)) (d - Suc l) l (-r, A'))"

(* sub2 preserves both dimensions (simp rule); case split on the result of the
   pivot search. *)
lemma sub2_preserves_dimensions[simp]:
  assumes rA': "sub2 d l (r,A) = (r',A')"
  shows "dim_row A' = dim_row A  dim_col A' = dim_col A"
proof (cases "find_non0 l A")
  case None then show ?thesis using rA' by auto next
  case (Some m) then show ?thesis using rA' by (cases "m = l", auto simp: Let_def)
qed

(* Carrier version of dimension preservation for sub2 (simp rule). *)
lemma sub2_closed [simp]:
  "sub2 d l (r,A) = (r',A')  A  carrier_mat m n  A'  carrier_mat m n"
  unfolding carrier_mat_def by auto

context 
  assumes sel_fun: "det_selection_fun sel_fun"
begin

(* sub2 keeps the first l columns triangular.
   If no pivot is found the matrix is unchanged.  Otherwise the proof first
   shows that swapping rows m and l (both at index l or below it, since
   l < m) preserves triangular_to l -- by a case split on whether the row
   index i is m, l, or neither -- and then appeals to
   sub1_preserves_triangle for the subsequent sub1 pass. *)
lemma sub2_preserves_triangle:
  assumes rA': "sub2 d l (r,A) = (r',A')"
  and tri: "triangular_to l A"
  and lc: "l < dim_col A"
  and ld: "l < d"
  and dr: "d  dim_row A"
  shows "triangular_to l A'"
proof -
  have lr: "l < dim_row A" using ld dr by auto
  show ?thesis
  proof (cases "find_non0 l A")
    case None then show ?thesis using rA' tri by simp next
    case (Some m) {
      have lm : "l < m" and mr : "m < dim_row A"
        using find_non0[OF sel_fun Some] by auto
      let ?A1 = "swaprows m l A"
  
      (* the row swap does not destroy triangularity of the first l columns *)
      have tri'': "triangular_to l ?A1"
      proof (intro triangular_toI)
        fix i j
        assume jl:"j < l" and ji:"j < i" and ir1: "i < dim_row ?A1"
        have jm: "j < m" using jl lm by auto
        have ir: "i < dim_row A" using ir1 by auto
        have jc: "j < dim_col A" using jl lc by auto
        show "?A1 $$ (i, j) = 0"
        proof (cases "m=i")
          case True {
            then have li: "l  i" using lm by auto
            hence "?A1 $$ (i,j) = A $$ (l,j)" using ir jc m=i by auto
            also have "... = 0" using tri jl lr by auto
            finally show ?thesis.
           }  next
          case False show ?thesis
            proof (cases "l=i")
              case True {
                then have "?A1 $$ (i,j) = A $$ (m,j)"
                  using ir jc mi by auto
                thus "?A1 $$ (i,j) = 0" using tri jl jm mr by auto
              } next
              case False {
                then have "?A1 $$ (i,j) = A $$ (i,j)"
                  using ir jc mi by auto
                thus "?A1 $$ (i,j) = 0" using tri jl ji ir by auto
              }
           qed
        qed
      qed
  
      let ?rA3 = "sub1 (?A1 $$ (l,l)) (d - Suc l) l (-r,?A1)"
      have [simp]: "dim_row ?A1 = dim_row A  dim_col ?A1 = dim_col A" by auto
      have rA'2: "?rA3 = (r',A')" using rA' Some by (simp add: Let_def)
      have "l + (d - Suc l) < dim_row A" using ld dr by auto
      thus ?thesis
        using sub1_preserves_triangle[OF rA'2 tri''] lr lc rA' by auto
    }
  qed
qed

(* After sub2 with d = dim_row A, the first Suc l columns are triangular.
   By triangle_growth two things are needed:
     - the first l columns stay triangular (sub2_preserves_triangle);
     - column l becomes triangular: if no pivot exists it already is
       (find_non0_all0); otherwise rows are swapped -- which keeps
       triangular_to l, shown by a case split on the row index -- and
       sub1_triangulizes_column is applied to the swapped matrix. *)
lemma sub2_grows_triangle:
  assumes mf: "mute_fun mf"
  and rA': "sub2 (dim_row A) l (r,A) = (r',A')"
  and tri: "triangular_to l A"
  and lc: "l < dim_col A"
  and lr: "l < dim_row A"
  shows "triangular_to (Suc l) A'"
proof (rule triangle_growth)
  show "triangular_to l A'"
    using sub2_preserves_triangle[OF rA' tri lc lr] by auto
    next
  have r0: "0 < dim_row A" using lr by auto
  show "triangular_column l A'"
    proof (cases "find_non0 l A")
      case None {
        then have "A' = A" using rA' by simp
        moreover have "triangular_column l A"  using find_non0_all0[OF None].
        ultimately show ?thesis by auto
      } next
      case (Some m) {
        have lm: "l < m" and mr: "m < dim_row A"
          using find_non0[OF sel_fun Some] by auto
        let ?A = "swaprows m l A"
        have tri2: "triangular_to l ?A"
          proof
            fix i j assume jl: "j < l" and ji:"j < i" and ir: "i < dim_row ?A"
            show "?A $$ (i,j) = 0"
            proof (cases "i = m")
              case True {
                then have "?A $$ (i,j) = A $$ (l,j)"
                  using jl lc ir by simp
                also have "... = 0"
                  using triangular_toD[OF tri jl jl] lr by auto
                finally show ?thesis by auto
              } next
              case False show ?thesis
                proof (cases "i = l")
                  case True {
                    then have "?A $$ (i,j) = A $$ (m,j)"
                      using jl lc ir by auto
                    also have "... = 0"
                      using triangular_toD[OF tri jl] jl lm mr by auto
                    finally show ?thesis by auto
                  } next
                  case False {
                    then have "?A $$ (i,j) = A $$ (i,j)"
                      using i  m jl lc ir by auto
                    thus ?thesis using tri jl ji ir by auto
                  }
                qed
            qed
          qed
        (* sub2 with a found pivot is exactly sub1 on the swapped matrix *)
        have rA'2: "sub1 (?A $$ (l,l)) (dim_row ?A - Suc l) l (-r, ?A) = (r',A')"
          using lm Some rA' by (simp add: Let_def)
        show ?thesis
          using sub1_triangulizes_column[OF mf rA'2 tri2] r0 lr lc by auto
      }
    qed
qed
end

subsection ‹Recursive Triangulization of Columns›

text ‹
  Now we recursively apply @{const sub2} to make the entire matrix triangular.
›

(* sub3 d l rA applies sub2 d to the columns 0, 1, ..., l - 1 in that order;
   for l = 0 it returns rA. *)
private fun sub3 :: "nat  nat  'a × 'a mat  'a × 'a mat"
  where "sub3 d 0 rA = rA"
  | "sub3 d (Suc l) rA = sub2 d l (sub3 d l rA)"

(* sub3 preserves both dimensions (simp rules); induction on l using the
   corresponding facts for sub2. *)
lemma sub3_preserves_dimensions[simp]:
  "sub3 d l (r,A) = (r',A')  dim_row A' = dim_row A"
  "sub3 d l (r,A) = (r',A')  dim_col A' = dim_col A"
proof (induction l arbitrary: r' A')
  case (Suc l)
    obtain r' A' where rA': "sub3 d l (r, A) = (r', A')" by force
    fix r'' A'' assume rA'': "sub3 d (Suc l) (r, A) = (r'', A'')" 
    then show "dim_row A'' = dim_row A" "dim_col A'' = dim_col A"
    using Suc rA' by auto
qed auto

(* Carrier version of dimension preservation for sub3 (simp rule). *)
lemma sub3_closed[simp]:
  "sub3 k l (r,A) = (r',A')  A  carrier_mat m n  A'  carrier_mat m n"
  unfolding carrier_mat_def by auto

(* After processing l columns with sub3 (with d = dim_row A), the first l
   columns of the result are triangular, as long as l does not exceed the
   dimensions of A.
   Induction on l; the step case applies sub2_grows_triangle to the
   intermediate matrix A'', whose dimensions equal those of A. *)
lemma sub3_makes_triangle:
  assumes mf: "mute_fun mf"
  and sel_fun: "det_selection_fun sel_fun"
  and "sub3 (dim_row A) l (r,A) = (r',A')"
  and "l  dim_row A"
  and "l  dim_col A"
  shows "triangular_to l A'"
  using assms
proof -
  show "sub3 (dim_row A) l (r,A) = (r',A')  l  dim_row A  l  dim_col A 
    triangular_to l A'"
  proof (induction l arbitrary:r' A')
    case (Suc l)
      then have Slr: "Suc l  dim_row A" and Slc: "Suc l  dim_col A" by auto
      hence lr: "l < dim_row A" and lc: "l < dim_col A" by auto
      moreover obtain r'' A''
        where rA'': "sub3 (dim_row A) l (r,A) = (r'',A'')" by force
      ultimately have IH: "triangular_to l A''" using Suc by auto
      have [simp]:"dim_row A'' = dim_row A" and [simp]:"dim_col A'' = dim_col A"
        using Slr Slc rA'' by auto
      fix r' A'
      assume "sub3 (dim_row A) (Suc l) (r, A) = (r', A')"
      then have rA': "sub2 (dim_row A'') l (r'',A'') = (r',A')"
        using rA'' by auto
      show "triangular_to (Suc l) A'"
        using sub2_grows_triangle[OF sel_fun mf rA'] lr lc rA'' IH by auto
  qed auto
qed

subsection ‹Triangulization›

(* Entry point of the triangulization: runs sub3 over dim_row A columns with
   d = dim_row A, starting from the factor 1 and the input matrix.
   The result is a pair (factor, transformed matrix). *)
definition triangulize :: "'a mat  'a × 'a mat"
where "triangulize A = sub3 (dim_row A) (dim_row A) (1,A)"

(* triangulize preserves both dimensions (simp rules). *)
lemma triangulize_preserves_dimensions[simp]:
  "triangulize A = (r',A')  dim_row A' = dim_row A"
  "triangulize A = (r',A')  dim_col A' = dim_col A"
  unfolding triangulize_def by auto

(* Carrier version of dimension preservation for triangulize (simp rule). *)
lemma triangulize_closed[simp]:
  "triangulize A = (r',A')  A  carrier_mat m n  A'  carrier_mat m n"
  unfolding carrier_mat_def by auto

context
  assumes mf: "mute_fun mf"
  and sel_fun: "det_selection_fun sel_fun"
begin

(* The matrix returned by triangulize on a square input is upper triangular.
   For n > 0 this is sub3_makes_triangle with l = dim_row A, translated via
   triangular_to_triangular; for n = 0 the result has no rows, so the claim
   holds directly from upper_triangular_def. *)
theorem triangulized:
  assumes "A  carrier_mat n n"
  and "triangulize A = (r',A')"
  shows "upper_triangular A'"
proof (cases "0<n")
  case True
    have rA': "sub3 (dim_row A) (dim_row A) (1,A) = (r',A')"
      using assms unfolding triangulize_def by auto
    have nr:"n = dim_row A" and nc:"n = dim_col A" and nr':"n = dim_row A'"
    using assms by auto
    thus ?thesis
      unfolding triangular_to_triangular
      using sub3_makes_triangle[OF mf sel_fun rA'] True by auto
    next
  case False
    then have nr':"dim_row A' = 0" using assms by auto
    thus ?thesis unfolding upper_triangular_def by auto
qed

subsection‹Divisor will not be 0›

text ‹
  Here we show that no sub-algorithm turns the component $r$
  of the input/output pair $(r,A)$ into 0.
  The algorithm @{term "sub1 A_ll k l (r,A)"} requires $A_{l,l} \neq 0$.
›

(* The factor returned by sub1 is subject to the same condition (r0) as the
   input factor r, given the hypothesis All on the first argument q; per the
   enclosing subsection this is "the divisor is not 0".
   Induction on k.  In the step case the last mute step multiplies the
   intermediate factor r'' by "fact", which is 1 if the processed entry is 0
   and q' (the first component returned by mf) otherwise; the mute_fun
   property is used to control q'. *)
lemma sub1_divisor [simp]:
  assumes rA': "sub1 q k l (r, A) = (r',A')"
  and r0: "r  0"
  and All: "q  0"
  and "k + l < dim_row A "
  and lc: "l < dim_col A"
  shows "r'  0"
using assms
proof -
  show "sub1 q k l (r,A) = (r',A')  k + l < dim_row A  r'  0"
  proof (induction k arbitrary: r' A')
    case (Suc k)
      obtain r'' A'' where rA'': "sub1 q k l (r, A) = (r'', A'')" by force
      then have IH: "r''  0" using Suc by auto
      obtain q' l' g where mf_id: "mf q (A'' $$ (Suc (l + k), l)) = (q',l',g)" by (rule prod_cases3)      
      (* factor contributed by the last mute step *)
      define fact where "fact = (if A'' $$ (Suc (l+k),l) = 0 then 1 else q')"
      note mf = mf[unfolded mute_fun_def, rule_format, OF mf_id]
      have All: "q  0"
        using sub1_preserves_diagnal[OF rA'' lc] All Suc by auto
      moreover have "fact  0" unfolding fact_def using All mf by auto
      moreover assume "sub1 q (Suc k) l (r,A) = (r',A')"
        then have "mute q (Suc (l + k)) l (r'',A'') = (r',A')"
          using rA'' by auto
        hence "r'' * fact = r'"
          unfolding mute.simps fact_def Let_def mf_id using IH by (auto split: if_splits)
      ultimately show "r'  0" using IH by auto
  qed (insert r0, simp)
qed

text ‹The algorithm @{term "sub2"} will not require such a condition.›
(* Divisor property for sub2.  No hypothesis on a diagonal entry is needed:
   when a pivot row m is found, the swapped matrix has the pivot entry
   A $$ (m,l) at position (l,l), so sub1_divisor applies; when no pivot is
   found the pair is returned unchanged (handled by the final auto). *)
lemma sub2_divisor [simp]:
  assumes rA': "sub2 k l (r, A) = (r',A')"
  and lk: "l < k"
  and kr: "k  dim_row A"
  and lc: "l < dim_col A"
  and r0: "r  0"
  shows "r'  0"
using assms
proof (cases "find_non0 l A") {
  case (Some m)
    then have Aml0: "A $$ (m,l)  0" using find_non0[OF sel_fun] by auto
    have md: "m < dim_row A" using find_non0[OF sel_fun Some] lk kr by auto
    let ?A'' = "swaprows m l A"
    have rA'2: "sub1 (?A'' $$ (l,l)) (k - Suc l) l (-r, ?A'') = (r',A')"
      using rA' Some by (simp add: Let_def)
    have All0: "?A'' $$ (l,l)  0" using Aml0 md lk kr lc by auto
    show ?thesis using sub1_divisor[OF rA'2 _ All0] r0 lk kr lc by simp
} qed auto

(* Divisor property for sub3, by induction on the number l of processed
   columns; the step case uses sub2_divisor on the intermediate pair. *)
lemma sub3_divisor [simp]:
  assumes "sub3 d l (r,A) = (r'',A'')"
  and "l  d"
  and "d  dim_row A"
  and "l  dim_col A"
  and r0: "r  0"
  shows "r''  0"
  using assms
proof -
  show
    "sub3 d l (r,A) = (r'',A'') 
     l  d  d  dim_row A  l  dim_col A  ?thesis"
  proof (induction l arbitrary: r'' A'')
    case 0
      then show ?case using r0 by simp
      next
    case (Suc l)
      obtain r' A' where rA': "sub3 d l (r,A) = (r',A')" by force
      then have [simp]:"dim_row A' = dim_row A" and [simp]:"dim_col A' = dim_col A"
        by auto
      from rA' have "r'  0" using Suc r0 by auto
      moreover have "sub2 d l (r',A') = (r'',A'')"
        using rA' Suc by simp
      ultimately show ?case using sub2_divisor using Suc by simp
  qed
qed

(* Divisor property for triangulize on a square matrix: instance of
   sub3_divisor with initial factor 1. *)
theorem triangulize_divisor:
  assumes A: "A  carrier_mat d d"
  shows "triangulize A = (r',A')  r'  0"
unfolding triangulize_def
proof -
  assume rA': "sub3 (dim_row A) (dim_row A) (1, A) = (r', A')"
  show ?thesis using sub3_divisor[OF rA'] A by auto 
qed

subsection ‹Determinant Preservation Results›

text ‹
  For each sub-algorithm $f$,
  we show $f(r,A) = (r',A')$ implies @{term "r * det A' = r' * det A"}.
›

(* Determinant bookkeeping for one mute step: r * det A' = r' * det A.
   If the targeted entry is 0 nothing changes.  Otherwise the step is
   multrow followed by addrow: det_addrow shows that the addrow part leaves
   the determinant unchanged, det_multrow contributes the factor q', and the
   same q' is multiplied into r. *)
lemma mute_det:
  assumes "A  carrier_mat n n"
  and rA': "mute q k l (r,A) = (r',A')"
  and "k < n"
  and "l < n"
  and "k  l"
  shows "r * det A' = r' * det A"
proof (cases "A $$ (k,l) = 0")
  case True
  thus ?thesis using assms by auto
next
  case False
  obtain p' q' g where mf_id: "mf q (A $$ (k,l)) = (q',p',g)" by (rule prod_cases3)
  let ?All = "q'"
  let ?Akl = "- p'"
  let ?B = "multrow k ?All A"
  let ?C = "addrow ?Akl k l ?B"
  have "r * det A' = r * det ?C"  using assms by (simp add: Let_def mf_id False)
  also have "det ?C = det ?B" using assms by (auto simp: det_addrow)
  also have " = ?All * det A" using assms det_multrow by auto
  also have "r *  = (r * ?All) * det A" by simp
  also have r: "r * ?All = r'" using assms by (simp add: Let_def mf_id False)
  finally show ?thesis.
qed

(* Determinant bookkeeping for sub1: r * det A'' = r'' * det A.
   Induction on k.  The step case multiplies the equation for the last mute
   step (mute_det) by r, substitutes the induction hypothesis, and cancels
   the common factor r', using the fact about r' that sub1_divisor provides. *)
lemma sub1_det:
  assumes A: "A  carrier_mat n n"
  and sub1: "sub1 q k l (r,A) = (r'',A'')"
  and r0: "r  0"
  and All0: "q  0"
  and l: "l + k < n"
  shows "r * det A'' = r'' * det A"
  using sub1 l
proof (induction k arbitrary: A'' r'')
  case 0
  then show ?case by auto
next
  case (Suc k)
  let ?rA' = "sub1 q k l (r,A)"
  obtain r' A' where rA':"?rA' = (r',A')" by force
  have A':"A'  carrier_mat n n" using sub1_closed[OF rA'] A by auto
  have IH: "r * det A' = r' * det A" using Suc assms rA' by auto
  assume "sub1 q (Suc k) l (r,A) = (r'',A'')"
  then have rA'':"mute q (Suc (l+k)) l (r',A') = (r'',A'')" using rA' by auto
  hence lem: "r' * det A'' = r'' * det A'"
    using assms Suc A' mute_det[OF A' rA''] by auto
  hence "r * r' * det A'' = r * r'' * det A'" by auto
  also from IH have "... = r'' * r' * det A" by auto
  finally have *: "r * r' * det A'' = r'' * r' * det A" .
  then have "r * r' * det A'' div r' = r'' * r' * det A div r'" by presburger
  moreover have "r'  0"
    using r0 sub1_divisor[OF rA'] All0 Suc A by auto
  ultimately show ?case using * by auto
qed

(* Determinant bookkeeping for sub2: r * det A' = r' * det A.
   Without a pivot the pair is unchanged.  With a pivot row m, sub1_det is
   applied to the swapped matrix and the negated factor -r, and
   det_swaprows relates the determinant of the swapped matrix to det A;
   the final step cancels r. *)
lemma sub2_det:
  assumes A: "A  carrier_mat d d"
  and rA': "sub2 d l (r,A) = (r',A')"
  and r0: "r  0"
  and ld: "l < d"
  shows "r * det A' = r' * det A"
proof (cases "find_non0 l A")
  case None then show ?thesis using assms by auto next
  case (Some m) {
    then have lm: "l < m" and md: "m < d"
      using A find_non0[OF sel_fun Some] ld by auto
    hence "m  l" by auto
    let ?A'' = "swaprows m l A"
    have rA'2: "sub1 (?A'' $$ (l,l)) (d - Suc l) l (-r, ?A'') = (r',A')"
      using rA' Some by (simp add: Let_def)
    have A'': "?A''  carrier_mat d d" using A by auto
    hence A''ll0: "?A'' $$ (l,l)  0"
      using find_non0[OF sel_fun Some] ld by auto
    hence "-r * det A' = r' * det ?A''"
      using sub1_det[OF A'' rA'2] ld A r0 by auto
    also have "r * ... = -r * r' * det A"
      using det_swaprows[OF md ld ml A] by auto
    finally have "r * -r * det A' = -r * r' * det A" by auto
    thus ?thesis using r0 by auto
  }
qed

(* Determinant bookkeeping for sub3: r * det A'' = r'' * det A.
   Induction on l.  The step case combines sub2_det for the last column with
   the induction hypothesis and cancels the intermediate factor r', using
   the fact r'0 about r' obtained from sub3_divisor. *)
lemma sub3_det:
  assumes A:"A  carrier_mat d d"
  and "sub3 d l (r,A) = (r'',A'')"
  and r0: "r  0"
  and "l  d"
  shows "r * det A'' = r'' * det A"
  using assms
proof -
  have d: "d = dim_row A" using A by auto
  show "sub3 d l (r,A) = (r'',A'')  l  d  r * det A'' = r'' * det A"
  proof (induction l arbitrary: r'' A'')
    case (Suc l)
      let ?rA' = "sub3 d l (r,A)"
      obtain r' A' where rA':"?rA' = (r',A')" by force
      then have rA'': "sub2 d l (r',A') = (r'',A'')"
        using Suc by auto
      have A': "A'  carrier_mat d d" using A rA' rA'' by auto
      have r'0: "r'  0" using r0 sub3_divisor[OF rA'] A Suc by auto
      have "r' * det A'' = r'' * det A'"
        using Suc r'0 A by(subst sub2_det[OF A' rA''],auto)
      also have "r * ... = r'' * (r * det A')" by auto
      also have "r * det A' = r' * det A" using Suc rA' by auto
      also have "r'' * ... div r' = r'' * r' * det A div r'" by (simp add: ac_simps)
      finally show "r * det A'' = r'' * det A" using r'0 
        by (metis r * det A' = r' * det A r' * det A'' = r'' * det A' 
          mult.left_commute mult_cancel_left)
  qed simp
qed

(* Relation between input and output of triangulize on a square matrix:
   det A * r' = det A'.  For d = 0, sub3 returns its input (1, A) unchanged;
   otherwise this is sub3_det with initial factor 1. *)
theorem triangulize_det:
  assumes A: "A  carrier_mat d d"
  and rA': "triangulize A = (r',A')"
  shows "det A * r' = det A'"
proof -
  have rA'2: "sub3 d d (1,A) = (r',A')"
    using A rA' unfolding triangulize_def by auto
  show ?thesis
  proof (cases "d = 0")
    case True
      then have A': "A'  carrier_mat 0 0" using A rA'2 by auto
      have rA'3: "(r',A') = (1,A)" using True rA'2 by simp
      thus ?thesis by auto
      next
    case False
      then show ?thesis using sub3_det[OF A rA'2] assms by auto
  qed
qed
end

subsection ‹Determinant Computation›

(* Executable determinant: for a square matrix, triangulize it to (m, A'),
   multiply the diagonal entries of A' and divide the product by the
   accumulated factor m; for a non-square matrix the result is 0. *)
definition det_code :: "'a mat  'a" where
  "det_code A = (if dim_row A = dim_col A then
     case triangulize A of (m,A') 
       prod_list (diag_mat A') div m
   else 0)"

(* Correctness of det_code: for a well-formed selection function and a
   function mf satisfying mute_fun, det_code A = det A.
   Square case: triangulize yields (r', A') with A' upper triangular
   (triangulized) and det A * r' = det A' (triangulize_det);
   det_upper_triangular turns det A' into the product of the diagonal, and
   the division by r' is cancelled using the fact r'0 from
   triangulize_divisor.
   Non-square case: closed by simp with det_def (the final "qed"). *)
lemma det_code[simp]: assumes sel_fun: "det_selection_fun sel_fun"
  and mf: "mute_fun mf"
  shows "det_code A = det A"
  using det_code_def[simp]
proof (cases "dim_row A = dim_col A")
  case True
  then have A: "A  carrier_mat (dim_row A) (dim_row A)" unfolding carrier_mat_def by auto
  obtain r' A' where rA': "triangulize A = (r',A')" by force
  from triangulize_divisor[OF mf sel_fun A] rA' have r'0: "r'  0" by auto
  from triangulize_det[OF mf sel_fun A rA'] have det': "det A * r' = det A'" by auto
  from triangulized[OF mf sel_fun A, unfolded rA'] have tri': "upper_triangular A'" by simp
  have A': "A'  carrier_mat (dim_row A') (dim_row A')"
    using triangulize_closed[OF rA' A] by auto
  from tri' have tr: "triangular_to (dim_row A') A'" by auto
  have "det_code A = prod_list (diag_mat A') div r'" using rA' True by simp
  also have "prod_list (diag_mat A') = det A'"
    unfolding det_upper_triangular[OF tri' A'] ..
  also have " = det A * r'" by (simp add: det')
  also have " div r' = det A" using r'0 by auto
  finally show ?thesis .
qed (simp add: det_def)

end
end

text ‹Now we can select an arbitrary selection and mute function. This will be important for computing
  resultants over polynomials, where usually a polynomial with small degree is preferable.

  The default however is to use the first element.›

(* Simplest mute function: returns its two arguments unchanged together with
   the third component 1. *)
definition trivial_mute_fun :: "'a :: comm_ring_1  'a  'a × 'a × 'a" where
  "trivial_mute_fun x y = (x,y,1)"

(* trivial_mute_fun meets the mute_fun specification (simp and intro rule). *)
lemma trivial_mute_fun[simp,intro]: "mute_fun trivial_mute_fun"
  unfolding mute_fun_def trivial_mute_fun_def by auto

(* Default selection function: the index stored in the head of the list. *)
definition fst_sel_fun :: "'a det_selection_fun" where
  "fst_sel_fun x = fst (hd x)" 

(* fst_sel_fun is a well-formed selection function (simp rule). *)
lemma fst_sel_fun[simp]: "det_selection_fun fst_sel_fun"
  unfolding det_selection_fun_def fst_sel_fun_def by auto

(* Selection function parameterised by a measure into nat: picks the index
   of an entry whose measure is minimal among the list entries. *)
context
  fixes measure :: "'a  nat" 
begin
(* select_min_main m i xs: m is the best measure seen so far and i the index
   that achieved it. The comparison is strict (n < m), so on ties the
   earlier index is kept. On the empty list the current index is returned. *)
private fun select_min_main where 
  "select_min_main m i ((j,p) # xs) = (let n = measure p in if n < m then select_min_main n j xs
    else select_min_main m i xs)"
| "select_min_main m i [] = i"

(* Starts the scan with the first list element. The case expression only
   has a clause for a non-empty list. *)
definition select_min :: "(nat × 'a) list  nat" where
  "select_min xs = (case xs of ((i,p) # ys)  (select_min_main (measure p) i ys))"

(* select_min satisfies det_selection_fun: for the list xs fixed below, the
   result is one of the first components occurring in xs. *)
lemma select_min[simp]: "det_selection_fun select_min"
  unfolding det_selection_fun_def 
proof (intro allI impI)
  fix xs :: "(nat × 'a)list"
  assume "xs  []"
  then obtain i p ys where xs: "xs = ((i,p) # ys)" by (cases xs, auto)
  then obtain m where id: "select_min xs = select_min_main m i ys" unfolding select_min_def by auto
  have "i  fst ` set xs" "set ys  set xs" unfolding xs by auto
  (* Induction over the remaining list, generalising the accumulator m and index i. *)
  thus "select_min xs  fst ` set xs" unfolding id
  proof (induct ys arbitrary: m i )
    case (Cons jp ys m i)
    obtain j p where jp: "jp = (j,p)" by force
    obtain k n where res: "select_min_main m i (jp # ys) = select_min_main n k ys" 
      and k: "k  fst ` set xs"
      using Cons(2-) unfolding jp by (cases "measure p < m"; force simp: Let_def)
    from Cons(1)[OF k, of n] Cons(3) 
    show ?case unfolding res by auto
  qed simp
qed
end

text ‹For the code equation we use the trivial mute and selection function as this does
  not impose any further class restrictions.›

(* Code equation for det: evaluate it via det_code instantiated with
   fst_sel_fun and trivial_mute_fun. *)
lemma det_code_fst_sel_fun[code]: "det A = det_code fst_sel_fun trivial_mute_fun A" by simp

text ‹But we also provide specialized functions for more specific carriers.›

(* Mute function for fields: returns the quotient x/y, the constant 1, and y. *)
definition field_mute_fun :: "'a :: field  'a  'a × 'a × 'a" where
  "field_mute_fun x y = (x/y,1,y)"

(* field_mute_fun satisfies mute_fun; registered as simp and intro rule. *)
lemma field_mute_fun[simp,intro]: "mute_fun field_mute_fun"
  unfolding mute_fun_def field_mute_fun_def by auto

(* Determinant for matrices over a field: det_code with fst_sel_fun and
   field_mute_fun. *)
definition det_field :: "'a :: field mat  'a" where 
  "det_field A = det_code fst_sel_fun field_mute_fun A"

(* det_field is extensionally equal to det. *)
lemma det_field[simp]: "det_field = det"
  by (intro ext, auto simp: det_field_def)

(* Mute function based on gcd: with g = gcd x y it returns x div g, y div g
   and g itself. *)
definition gcd_mute_fun :: "'a :: ring_gcd  'a  'a × 'a × 'a" where
  "gcd_mute_fun x y = (let g = gcd x y in (x div g, y div g,g))"

(* gcd_mute_fun satisfies mute_fun; registered as simp and intro rule. *)
lemma gcd_mute_fun[simp,intro]: "mute_fun gcd_mute_fun"
  unfolding mute_fun_def gcd_mute_fun_def by (auto simp: Let_def div_mult_swap mult.commute)

(* Determinant for integer matrices: det_code with select_min measured by
   the absolute value (converted to nat) and gcd_mute_fun. *)
definition det_int :: "int mat  int" where 
  "det_int A = det_code (select_min (λ x. nat (abs x))) gcd_mute_fun A"

(* det_int is extensionally equal to det. *)
lemma det_int[simp]: "det_int = det"
  by (intro ext, auto simp: det_int_def)

(* Determinant for matrices of polynomials over a field: det_code with
   select_min measured by the polynomial degree and gcd_mute_fun. *)
definition det_field_poly :: "'a :: {field,field_gcd} poly mat  'a poly" where
  "det_field_poly A = det_code (select_min degree) gcd_mute_fun A"

(* det_field_poly is extensionally equal to det. *)
lemma det_field_poly[simp]: "det_field_poly = det"
  by (intro ext, auto simp: det_field_poly_def)

end

Theory Show_Matrix

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Converting Matrices to Strings›

text ‹We just instantiate matrices in the show-class by printing them as lists of lists.›

theory Show_Matrix
imports
  Show.Show
  Matrix
begin

(* Show function for vectors: shows the list obtained via list_of_vec. *)
definition shows_vec :: "'a :: show vec  shows" where
  "shows_vec v  shows (list_of_vec v)"

(* Instance of the show class for vectors. The precedence argument p is
   ignored; lists of vectors are joined with the separator shown below. *)
instantiation vec :: ("show") "show"
begin

definition "shows_prec p (v :: 'a vec)  shows_vec v"
definition "shows_list (vs :: 'a vec list)  shows_sep shows_vec (shows '', '') vs"

instance 
  by standard (simp_all add: shows_vec_def show_law_simps shows_prec_vec_def shows_list_vec_def)
end

(* Show function for matrices: shows the value obtained via mat_to_list. *)
definition shows_mat :: "'a :: show mat  shows" where
  "shows_mat A  shows (mat_to_list A)"

(* Instance of the show class for matrices. The precedence argument p is
   ignored; lists of matrices are joined with the separator shown below. *)
instantiation mat :: ("show") "show"
begin

definition "shows_prec p (A :: 'a mat)  shows_mat A"
definition "shows_list (As :: 'a mat list)  shows_sep shows_mat (shows '', '') As"

instance 
  by standard (simp_all add: shows_mat_def show_law_simps shows_prec_mat_def shows_list_mat_def)
end

end

Theory Char_Poly

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Characteristic Polynomial›

text ‹We define eigenvalues, eigenvectors, and the characteristic polynomial. We further prove
  that the eigenvalues are exactly the roots of the characteristic polynomial.
  Finally, we apply the fundamental theorem of algebra to show that the characteristic polynomial
  of a complex matrix can always be represented as a product of linear factors $x - a$.›
 
theory Char_Poly
imports 
  Polynomial_Factorization.Fundamental_Theorem_Algebra_Factorized
  Polynomial_Interpolation.Missing_Polynomial
  Polynomial_Interpolation.Ring_Hom_Poly
  Determinant
  Complex_Main
begin

(* eigenvector A v k relates a matrix A, a vector v and a scalar k. The body
   combines three conditions: one on v and carrier_vec (dim_row A), one on v
   and the zero vector of that dimension, and the equation between A applied
   to v and v scaled by k.
   NOTE(review): the logical/relational symbols are not visible in this text;
   presumably membership, disequality and conjunction - confirm. *)
definition eigenvector :: "'a :: comm_ring_1 mat  'a vec  'a  bool" where
  "eigenvector A v k = (v  carrier_vec (dim_row A)  v  0v (dim_row A)  A *v v = k v v)"
  
(* For an eigenvector v of A with scalar k, applying the i-th matrix power of
   A to v equals scaling v by k^i. Proved by induction on i. *)
lemma eigenvector_pow: assumes A: "A  carrier_mat n n"
  and ev: "eigenvector A v (k :: 'a :: comm_ring_1)"
  shows "A ^m i *v v = k^i v v" 
proof -
  let ?G = "monoid_vec TYPE ('a) n"
  from A have dim: "dim_row A = n" by auto
  from ev[unfolded eigenvector_def dim]
  have v: "v  carrier_vec n" and Av: "A *v v = k v v" by auto
  interpret v: comm_group ?G by (rule comm_group_vec)
  show ?thesis
  proof (induct i)
    case 0
    show ?case using v dim by simp
  next
    case (Suc i)
    (* P abbreviates the i-th power, so that the Suc-th power is P * A. *)
    define P where "P = A ^m i"
    have P: "P  carrier_mat n n" using A unfolding P_def by simp
    have "A ^m Suc i = P * A" unfolding P_def by simp
    also have " *v v = P *v (A *v v)" using P A v by simp
    also have "A *v v = k v v" by (rule Av)
    also have "P *v (k v v) = k v (P *v v)" 
      by (rule eq_vecI, insert v P, auto)
    also have "(P *v v) = (k ^ i) v v" unfolding P_def by (rule Suc)
    also have "k v ((k ^ i) v v) = (k * k ^ i) v v"
      by (rule eq_vecI, insert v, auto)
    also have "k * k ^ i = k ^ (Suc i)" by auto
    finally show ?case .
  qed
qed
    


(* eigenvalue A k is stated in terms of a bound vector v with
   eigenvector A v k.
   NOTE(review): the quantifier symbol is not visible in this text;
   presumably existential (later proofs obtain such a v from eigenvalue_def)
   - confirm. *)
definition eigenvalue :: "'a :: comm_ring_1 mat  'a  bool" where
  "eigenvalue A k = ( v. eigenvector A v k)"


(* char_matrix A e: A plus the identity matrix of dimension dim_row A scaled
   by -e. *)
definition char_matrix :: "'a :: field mat  'a  'a mat" where
  "char_matrix A e = A + ((-e) m (1m (dim_row A)))"

(* Carrier rule relating A in carrier_mat n n and char_matrix A e in
   carrier_mat n n. *)
lemma char_matrix_closed[simp]: "A  carrier_mat n n  char_matrix A e  carrier_mat n n"
  unfolding char_matrix_def by auto

(* Reformulates eigenvector via char_matrix: the equation between A applied to
   v and v scaled by e is replaced by char_matrix A e applied to v being the
   zero vector. *)
lemma eigenvector_char_matrix: assumes A: "(A :: 'a :: field mat)  carrier_mat n n" 
  shows "eigenvector A v e = (v  carrier_vec n  v  0v n  char_matrix A e *v v = 0v n)"
proof -
  from A have dim: "dim_row A = n" "dim_col A = n" by auto
  {
    assume v: "v  carrier_vec n"
    hence dimv: "dim_vec v = n" by auto
    (* Move e-scaled v to the left-hand side; both directions add a vector via arg_cong. *)
    have "(A *v v = e v v) = (A *v v + (-e) v v = 0v n)" (is "?id1 = ?id2")
    proof
      assume ?id1
      from arg_cong[OF this, of "λ w. w + (-e) v v"]
      show ?id2 using A v by auto
    next
      assume ?id2
      have "A *v v + - e v v + e v v = A *v v" using A v by auto
      from arg_cong[OF ?id2, of "λ w. w + e v v", unfolded this]
      show ?id1 using A v by simp
    qed
    also have "(A *v v + (-e) v v) = char_matrix A e *v v" unfolding char_matrix_def
      by (rule eq_vecI, insert v A dim, auto simp: add_scalar_prod_distrib[of _ n])
    finally have "(A *v v = e v v) = (char_matrix A e *v v = 0v n)" .
  }
  thus ?thesis unfolding eigenvector_def dim by blast
qed

(* Reformulates eigenvalue via char_matrix, obtained by unfolding
   eigenvalue_def and eigenvector_char_matrix. *)
lemma eigenvalue_char_matrix: assumes A: "(A :: 'a :: field mat)  carrier_mat n n" 
  shows "eigenvalue A e = ( v. v  carrier_vec n  v  0v n  char_matrix A e *v v = 0v n)"
  unfolding eigenvalue_def eigenvector_char_matrix[OF A] .. 

(* Computes a candidate eigenvector: runs gauss_jordan on char_matrix A e
   paired with a zero matrix with dim_row A rows and 0 columns, and applies
   find_base_vector to the first component of the result. *)
definition find_eigenvector :: "'a::field mat  'a  'a vec" where
  "find_eigenvector A e = 
    find_base_vector (fst (gauss_jordan (char_matrix A e) (0m (dim_row A) 0)))"

(* Soundness of find_eigenvector: if e is an eigenvalue of the square matrix
   A, then find_eigenvector A e is an eigenvector of A for e. *)
lemma find_eigenvector: assumes A: "A  carrier_mat n n"
  and ev: "eigenvalue A e"
  shows "eigenvector A (find_eigenvector A e) e"
proof -
  define B where "B = char_matrix A e"
  (* Obtain a witness vector v from the eigenvalue assumption. *)
  from ev[unfolded eigenvalue_char_matrix[OF A]]  obtain v where
    v: "v  carrier_vec n" "v  0v n" and Bv: "B *v v = 0v n" unfolding B_def by auto
  have B: "B  carrier_mat n n" using A unfolding B_def by simp
  let ?z = "0m (dim_row A) 0"
  (* C is the first component of the gauss_jordan result; w is the computed vector. *)
  obtain C D where gauss: "gauss_jordan B ?z = (C,D)" by force
  define w where "w = find_base_vector C"
  have res: "find_eigenvector A e = w" unfolding w_def find_eigenvector_def Let_def gauss B_def[symmetric]
    by simp
  have "?z  carrier_mat n 0" using A by auto
  note gauss_0 = gauss_jordan[OF B this gauss] 
  hence C: "C  carrier_mat n n" by auto
  from gauss_0(1)[OF v(1)] Bv have Cv: "C *v v = 0v n" by auto
  (* C cannot be the identity matrix: assuming so leads to False using Cv and v. *)
  {
    assume C: "C = 1m n"
    (* NOTE(review): `id` is used here before the local fact named `id` is
       introduced further below; presumably it resolves to a fact in scope
       from elsewhere - confirm. *)
    have False using id Cv v unfolding C by auto
  }
  hence C1: "C  1m n" by auto
  from find_base_vector_not_1[OF gauss_jordan_row_echelon[OF B gauss] C C1]
  have w: "w  carrier_vec n" "w  0v n" and id: "C *v w = 0v n" unfolding w_def by auto
  (* Transfer the equation for C back to B via gauss_0. *)
  from gauss_0(1)[OF w(1)] id have Bw: "B *v w = 0v n" by simp
  from w Bw have "eigenvector A w e" 
    unfolding eigenvector_char_matrix[OF A] B_def by auto
  thus ?thesis unfolding res .
qed

(* A square matrix that has an eigenvalue has positive dimension. The case
   n = 0 is refuted using the eigenvector obtained from the assumption. *)
lemma eigenvalue_imp_nonzero_dim: assumes "A  carrier_mat n n"
  and "eigenvalue A ev"
  shows "n > 0"
proof (cases n)
  case 0
  from assms obtain v where "eigenvector A v ev" unfolding eigenvalue_def by auto
  from this[unfolded eigenvector_def] assms 0 
  have "v  carrier_vec 0" "v  0v 0" by auto
  hence False by auto
  thus ?thesis by auto
qed simp
  
(* Over a field, e is an eigenvalue of A exactly when the determinant of
   char_matrix A e is 0. *)
lemma eigenvalue_det: assumes A: "(A :: 'a :: field mat)  carrier_mat n n" shows
  "eigenvalue A e = (det (char_matrix A e) = 0)"
proof -
  from A have cA: "char_matrix A e  carrier_mat n n" by auto
  show ?thesis
    unfolding eigenvalue_char_matrix[OF A]
    (* NOTE(review): no local fact named `id` is introduced in this proof;
       presumably it resolves to a fact in scope from elsewhere - confirm. *)
    unfolding id det_0_negate[OF cA] det_0_iff_vec_prod_zero[OF cA]
      eigenvalue_def by auto
qed

(* Polynomial matrix used for the characteristic polynomial: the identity of
   dimension dim_row A scaled by the polynomial [:0,1:], plus A with every
   entry a mapped to the constant polynomial [: - a :]. *)
definition char_poly_matrix :: "'a :: comm_ring_1 mat  'a poly mat" where
  "char_poly_matrix A = (([:0,1:] m 1m (dim_row A)) + map_mat (λ a. [: - a :]) A)"    

(* Carrier rule relating A in carrier_mat n n and char_poly_matrix A in
   carrier_mat n n. *)
lemma char_poly_matrix_closed[simp]: "A  carrier_mat n n  char_poly_matrix A  carrier_mat n n"
  unfolding char_poly_matrix_def by auto    

(* The characteristic polynomial: the determinant of char_poly_matrix A. *)
definition char_poly :: "'a :: comm_ring_1 mat  'a poly" where
  "char_poly A = (det (char_poly_matrix A))"    

(* Bundle of both definitions, for unfolding them together. *)
lemmas char_poly_defs = char_poly_def char_poly_matrix_def

(* In the comm_ring_hom locale: char_poly_matrix of the mapped matrix equals
   char_poly_matrix A mapped entrywise with map_poly hom. *)
lemma (in comm_ring_hom) char_poly_matrix_hom: assumes A: "A  carrier_mat n n"
  shows "char_poly_matrix (math A) = map_mat (map_poly hom) (char_poly_matrix A)"
  unfolding char_poly_defs
  by (rule eq_matI, insert A, auto simp: smult_mat_def hom_distribs)

(* In the comm_ring_hom locale: the characteristic polynomial of map_mat hom A
   is map_poly hom applied to the characteristic polynomial of A. *)
lemma (in comm_ring_hom) char_poly_hom: assumes A: "A  carrier_mat n n"
  shows "char_poly (map_mat hom A) = map_poly hom (char_poly A)"
proof -
  (* Needed to use hom_det for the lifted homomorphism on polynomials. *)
  interpret map_poly_hom: map_poly_comm_ring_hom hom..
  show ?thesis
    unfolding char_poly_def map_poly_hom.hom_det[symmetric] char_poly_matrix_hom[OF A] ..
qed

(* Transfer of eigenvectors and eigenvalues along the homomorphism of the
   inj_comm_ring_hom locale, in both directions for eigenvectors. *)
context inj_comm_ring_hom
begin

(* An eigenvector of A for ev is mapped to an eigenvector of the mapped
   matrix for hom ev. *)
lemma eigenvector_hom: assumes A: "A  carrier_mat n n"
  and ev: "eigenvector A v ev"
  shows "eigenvector (math A) (vech v) (hom ev)"
proof -
  let ?A = "math A" 
  let ?v = "vech v"
  let ?ev = "hom ev"
  from ev[unfolded eigenvector_def] A
  have v: "v  carrier_vec n" "v  0v n" "A *v v = ev v v" by auto
  from v(1) have v1: "?v  carrier_vec n" by simp
  (* Pick an index i witnessing the second fact about v, and carry it over to ?v. *)
  from v(1-2) obtain i where "i < n" and "v $ i  0" by force
  with v(1) have "?v $ i  0" by auto
  hence v2: "?v  0v n" using i < n v(1) by force
  from arg_cong[OF v(3), of "vech", unfolded mult_mat_vec_hom[OF A v(1)] vec_hom_smult]
  have v3: "?A *v ?v = ?ev v ?v" .
  from v1 v2 v3
  show ?thesis unfolding eigenvector_def using A by auto
qed

(* Eigenvalue version of eigenvector_hom. *)
lemma eigenvalue_hom: assumes A: "A  carrier_mat n n"
  and ev: "eigenvalue A ev"
  shows "eigenvalue (math A) (hom ev)"
  using eigenvector_hom[OF A, of _ ev] ev
  unfolding eigenvalue_def by auto

(* Converse of eigenvector_hom: an eigenvector statement about the mapped
   matrix, vector and scalar is pulled back to A, v and ev (uses vec_hom_inj). *)
lemma eigenvector_hom_rev: assumes A: "A  carrier_mat n n"
  and ev: "eigenvector (math A) (vech v) (hom ev)"
  shows "eigenvector A v ev"
proof -
  let ?A = "math A" 
  let ?v = "vech v"
  let ?ev = "hom ev"
  from ev[unfolded eigenvector_def] A
  have v: "v  carrier_vec n" "?v  0v n" "?A *v ?v = ?ev v ?v" by auto
  from v(1-2) obtain i where "i < n" and "v $ i  0" by force
  with v(1) have "v $ i  0" by auto
  hence v2: "v  0v n" using i < n v(1) by force
  from vec_hom_inj[OF v(3)[folded mult_mat_vec_hom[OF A v(1)] vec_hom_smult]]
  have v3: "A *v v = ev v v" .
  from v(1) v2 v3
  show ?thesis unfolding eigenvector_def using A by auto
qed

end


(* If evaluating every entry of the polynomial matrix B at k gives the
   corresponding entry of A (for indices below n), then evaluating det B at
   k gives det A. Proved summand-wise on the permutation expansion det_def'. *)
lemma poly_det_cong: assumes A: "A  carrier_mat n n"
  and B: "B  carrier_mat n n"
  and poly: " i j. i < n  j < n  poly (B $$ (i,j)) k = A $$ (i,j)"
  shows "poly (det B) k = det A" 
proof -
  show ?thesis
  unfolding det_def'[OF A] det_def'[OF B] poly_sum poly_mult poly_prod
  proof (rule sum.cong[OF refl])
    fix x
    assume x: "x  {p. p permutes {0..<n}}"
    let ?l = "ka = 0..<n. poly (B $$ (ka, x ka)) k"
    let ?r = "i = 0..<n. A $$ (i, x i)"
    have id: "?l = ?r"
      by (rule prod.cong[OF refl poly], insert x, auto)
    show "poly (signof x) k * ?l = signof x * ?r" unfolding id signof_def by auto
  qed
qed

(* Evaluating the characteristic polynomial of A at k gives the determinant of
   the negated char_matrix A k. Instance of poly_det_cong. *)
lemma char_poly_matrix: assumes A: "(A :: 'a :: field mat)  carrier_mat n n"
  shows "poly (char_poly A) k = det (- (char_matrix A k))"  unfolding char_poly_def
  by (rule poly_det_cong[of _ n], insert A, auto simp: char_poly_matrix_def char_matrix_def)

(* Relates eigenvalue A k to the characteristic polynomial of A evaluating to
   0 at k; derived from eigenvalue_det, char_poly_matrix and det_0_negate. *)
lemma eigenvalue_root_char_poly: assumes A: "(A :: 'a :: field mat)  carrier_mat n n"
  shows "eigenvalue A k  poly (char_poly A) k = 0" 
  unfolding eigenvalue_det[OF A] char_poly_matrix[OF A] 
  by (subst det_0_negate[of _ n], insert A, auto)

(* Results for a fixed upper triangular square matrix A. *)
context
  fixes A :: "'a :: comm_ring_1 mat" and n :: nat
  assumes A: "A  carrier_mat n n"
  and ut: "upper_triangular A"
begin
(* char_poly_matrix preserves upper triangularity. *)
lemma char_poly_matrix_upper_triangular: "upper_triangular (char_poly_matrix A)"
  using A ut unfolding upper_triangular_def char_poly_matrix_def by auto 

(* The characteristic polynomial is expressed through the linear factors
   [:- a, 1:] for the diagonal entries a of A; the proof reduces it to the
   prod_list of a diagonal via det_upper_triangular. *)
lemma char_poly_upper_triangular: 
  "char_poly A = ( a  diag_mat A. [:- a, 1:])"
proof -
  from A have cA: "char_poly_matrix A  carrier_mat n n" by simp
  show ?thesis
    unfolding char_poly_def det_upper_triangular [OF char_poly_matrix_upper_triangular cA]
    by (rule arg_cong[where f = prod_list], unfold list_eq_iff_nth_eq, insert cA A, auto simp: diag_mat_def
      char_poly_matrix_def)
qed
end

(* Embedding matrix entries as constant polynomials commutes with matrix
   multiplication (?id). The variants ?left and ?right additionally multiply
   each entry by a polynomial p, placed on the left or right factor. *)
lemma map_poly_mult: assumes A: "A  carrier_mat nr n"
  and B: "B  carrier_mat n nc"
  shows 
    "map_mat (λ a. [: a :]) (A * B) = map_mat (λ a. [: a :]) A * map_mat (λ a. [: a :]) B" (is "?id")
    "map_mat (λ a. [: a :] * p) (A * B) = map_mat (λ a. [: a :] * p) A * map_mat (λ a. [: a :]) B" (is "?left")
    "map_mat (λ a. [: a :] * p) (A * B) = map_mat (λ a. [: a :]) A * map_mat (λ a. [: a :] * p) B" (is "?right")
proof -
  from A B have dim: "dim_row A = nr" "dim_col A = n" "dim_row B = n" "dim_col B = nc" by auto
  (* Entry-wise fact for ?id, stated on row/column scalar products. *)
  {
    fix i j
    have "i < nr  j < nc  
      row (map_mat (λa. [:a:]) A) i  col (map_mat (λa. [:a:]) B) j = [:(row A i  col B j):]"
      unfolding scalar_prod_def
      by (auto simp: dim ac_simps, induct n, auto)  
  } note id = this 
  (* Entry-wise fact for ?left. *)
  {
    fix i j
    have "i < nr  j < nc  
      [:(row A i  col B j):] * p = row (map_mat (λ a. [: a :] * p) A) i  col (map_mat (λa. [:a:]) B) j"
      unfolding scalar_prod_def
      by (auto simp: dim ac_simps smult_sum) 
  } note left = this 
  (* Entry-wise fact for ?right. *)
  {
    fix i j
    have "i < nr  j < nc  
      [:(row A i  col B j):] * p = row (map_mat (λ a. [: a :]) A) i  col (map_mat (λa. [:a:] * p) B) j"
      unfolding scalar_prod_def
      by (auto simp: dim ac_simps smult_sum) 
  } note right = this 
  show ?id
    by (rule eq_matI, insert id, auto simp: dim)
  show ?left
    by (rule eq_matI, insert left, auto simp: dim)
  show ?right
    by (rule eq_matI, insert right, auto simp: dim)
qed

(* Similar matrices have the same characteristic polynomial. With
   A = P * B * Q and P * Q the identity, the polynomial matrix of A is
   rewritten as ?P * (?XI + ?B) * ?Q, and det_mult together with
   det (?P * ?Q) = 1 finishes the calculation. *)
lemma char_poly_similar: assumes "similar_mat A (B :: 'a :: comm_ring_1 mat)"
  shows "char_poly A = char_poly B"
proof -
  from similar_matD[OF assms] obtain n P Q where
  carr: "{A, B, P, Q}  carrier_mat n n" (is "_  ?C")
  and PQ: "P * Q = 1m n" 
  and AB: "A = P * B * Q" by auto
  hence A: "A  ?C" and B: "B  ?C" and P: "P  ?C" and Q: "Q  ?C" by auto
  (* Abbreviations for the polynomial-embedded versions of P, Q, B and the identity. *)
  let ?m = "λ a. [: -a :]"
  let ?P = "map_mat (λ a. [: a :]) P"
  let ?Q = "map_mat (λ a. [: a :]) Q"
  let ?B = "map_mat ?m B"
  let ?I = "map_mat (λ a. [: a :]) (1m n)"
  let ?XI = "[:0, 1:] m 1m n"
  from A B have dim: "dim_row A = n" "dim_row B = n" by auto
  have cong: " x y z. x = y  x * z = y * z" by auto
  (* Negation expressed as multiplication by [: -1 :], to fit map_poly_mult. *)
  have id: "?m = (λ a :: 'a. [: a :] * [: -1 :])" by (intro ext, auto)
  have "char_poly A = det (?XI + map_mat (λa. [:- a:]) (P * B * Q))" unfolding 
    char_poly_defs dim 
    by (simp add: AB)
  also have "?XI = ?P * ?XI * ?Q" (is "_ = ?left")
  proof -
    have "?P * ?XI = [:0, 1:] m (?P * 1m n)" 
      by (rule mult_smult_distrib[of _ n n _ n], insert P, auto)
    also have "?P * 1m n = ?P" using P by simp
    also have "([: 0, 1:] m ?P) * ?Q = [: 0, 1:] m (?P * ?Q)"
      by (rule mult_smult_assoc_mat, insert P Q, auto)
    also have "?P * ?Q = ?I" unfolding PQ[symmetric]
      by (rule map_poly_mult[symmetric, OF P Q])
    also have "[: 0, 1:] m ?I = ?XI"
      by rule auto
    finally show ?thesis ..
  qed
  also have "map_mat ?m (P * B * Q) = ?P * ?B * ?Q" (is "_ = ?right")
    unfolding id
    by (subst map_poly_mult[OF mult_carrier_mat[OF P B] Q],
      subst map_poly_mult(3)[OF P B], simp)
  also have "?left + ?right = (?P * ?XI + ?P * ?B) * ?Q"
    by (rule add_mult_distrib_mat[symmetric, of _ n n], insert B P Q, auto)
  also have "?P * ?XI + ?P * ?B = ?P * (?XI + ?B)"
    by (rule mult_add_distrib_mat[symmetric, of _ n n], insert B P Q, auto)
  also have "det (?P * (?XI + ?B) * ?Q) = det ?P * det (?XI + ?B) * det ?Q"
    by (rule trans[OF det_mult[of _ n] cong[OF det_mult]], insert P Q B, auto)
  also have " = (det ?P * det ?Q) * det (?XI + ?B)" by (simp add: ac_simps)
  also have "det (?XI + ?B) = char_poly B" unfolding char_poly_defs dim by simp
  also have "det ?P * det ?Q = det (?P * ?Q)"
    by (rule det_mult[symmetric], insert P Q, auto)
  also have "?P * ?Q = ?I" unfolding PQ[symmetric]
    by (rule map_poly_mult[symmetric, OF P Q])
  also have "det  = prod_list (diag_mat ?I)"
    by (rule det_upper_triangular[of _ n], auto)
  also have " = 1" unfolding prod_list_diag_prod
    by (rule prod.neutral) simp
  finally show ?thesis by simp
qed

(* Multiplying a polynomial q by signof p does not change its degree; proved
   by a case split on whether sign p = 1. *)
lemma degree_signof_mult[simp]: "degree (signof p * q) = degree q"
  by (cases "sign p = 1", auto simp: signof_def)

(* For an n x n matrix the characteristic polynomial has degree n and its
   coefficient at n is 1. In the permutation expansion the summand for the
   identity permutation has degree n and coefficient 1 at n, while every
   other permutation gives a summand of degree below n. *)
lemma degree_monic_char_poly: assumes A: "A  carrier_mat n n"
  shows "degree (char_poly A) = n  coeff (char_poly A) n = 1"
proof -
  from A have A': "[:0, 1:] m 1m (dim_row A) + map_mat (λa. [:- a:]) A  carrier_mat n n" by auto
  from A have dA: "dim_row A = n" by simp
  show ?thesis
    unfolding char_poly_defs det_def'[OF A']
  proof (rule degree_lcoeff_sum[of _ id], auto simp: finite_permutations permutes_id dA)
    (* Identity permutation: product over the diagonal entries. *)
    have both: "degree (i = 0..<n. ([:0, 1:] m 1m n + map_mat (λa. [:- a:]) A) $$ (i, i)) = n 
      coeff (i = 0..<n. ([:0, 1:] m 1m n + map_mat (λa. [:- a:]) A) $$ (i, i)) n = 1"
      by (rule degree_prod_monic, insert A, auto)
    from both show "degree (i = 0..<n. ([:0, 1:] m 1m n + map_mat (λa. [:- a:]) A) $$ (i, i)) = n" ..
    from both show "coeff (i = 0..<n. ([:0, 1:] m 1m n + map_mat (λa. [:- a:]) A) $$ (i, i)) n = 1" ..
  next
    (* Any other permutation p: obtain an index i below n about which pi is recorded. *)
    fix p
    assume p: "p permutes {0..<n}"
      and "p  id"
    then obtain i where i: "i < n" and pi: "p i  i"
      by (metis atLeastLessThan_iff order_refl permutes_natset_le)
    show "degree (i = 0..<n. ([:0, 1:] m 1m n + map_mat (λa. [:- a:]) A) $$ (i, p i)) < n"
      by (rule degree_prod_sum_lt_n[OF _ i], insert p i pi A, auto)
  qed
qed

(* For a complex n x n matrix, the statement concerns a list as of length n
   whose linear factors [:- a, 1:] give the characteristic polynomial.
   Combines fundamental_theorem_algebra_factorized with
   degree_monic_char_poly. *)
lemma char_poly_factorized: fixes A :: "complex mat"
  assumes A: "A  carrier_mat n n"
  shows " as. char_poly A = ( a  as. [:- a, 1:])  length as = n"
proof -
  let ?p = "char_poly A"
  from fundamental_theorem_algebra_factorized[of ?p] obtain as
  where "Polynomial.smult (coeff ?p (degree ?p)) (aas. [:- a, 1:]) = ?p" by blast
  (* The leading coefficient is 1, so the smult factor disappears. *)
  also have "coeff ?p (degree ?p) = 1" using degree_monic_char_poly[OF A] by simp
  finally have cA: "?p = (aas. [:- a, 1:])" by simp
  from degree_monic_char_poly[OF A] have "degree ?p = n" ..
  with degree_linear_factors[of uminus as, folded cA] have "length as = n" by auto
  with cA show ?thesis by blast
qed

(* For a four-block matrix with a 1 x 1 upper-left block A1 and a zero
   lower-left block, the characteristic polynomial is the product of the
   characteristic polynomials of A1 and A3. *)
lemma char_poly_four_block_zeros_col: assumes A1: "A1  carrier_mat 1 1"
  and A2: "A2  carrier_mat 1 n" and A3: "A3  carrier_mat n n"
  shows "char_poly (four_block_mat A1 A2 (0m n 1) A3) = char_poly A1 * char_poly A3" 
    (is "char_poly ?A = ?cp1 * ?cp3")
proof -
  (* ?cm is the unfolded form of char_poly_matrix. *)
  let ?cm = "λ A. [:0, 1:] m 1m (dim_row A) +
         map_mat (λa. [:- a:]) A"
  let ?B2 = "map_mat (λa. [:- a:]) A2"
  have "char_poly ?A = det (?cm ?A)"
    unfolding char_poly_defs using A1 A3 by simp
  also have "?cm ?A = four_block_mat (?cm A1) ?B2 (0m n 1) (?cm A3)"
    by (rule eq_matI, insert A1 A2 A3, auto simp: one_poly_def)
  also have "det  = det (?cm A1) * det (?cm A3)"
    by (rule det_four_block_mat_lower_left_zero_col[OF _ _ refl], insert A1 A2 A3, auto)
  also have " = ?cp1 * ?cp3" unfolding char_poly_defs ..
  finally show ?thesis .
qed

(* Transposing a square matrix does not change its characteristic polynomial;
   follows from det_transpose applied to the polynomial matrix. *)
lemma char_poly_transpose_mat[simp]: assumes A: "A  carrier_mat n n"
  shows "char_poly (transpose_mat A) = char_poly A"
proof -
  let ?A = "[:0, 1:] m 1m (dim_row A) + map_mat (λa. [:- a:]) A"
  have A': "?A  carrier_mat n n" using A by auto
  show ?thesis unfolding char_poly_defs
    by (subst det_transpose[symmetric, OF A'], rule arg_cong[of _ _ det],
    insert A, auto)
qed

(* The derivative of the characteristic polynomial of an n x n matrix A is
   expressed via the characteristic polynomials of mat_delete A i i for
   i < n. Proof outline: differentiate each summand of the permutation
   expansion (to_f), swap the two sums, and for every fixed index a identify
   the inner sum with the determinant for mat_delete A a a (to_char_poly). *)
lemma pderiv_char_poly: fixes A :: "'a :: idom mat" 
  assumes A: "A  carrier_mat n n" 
  shows "pderiv (char_poly A) = (i < n. char_poly (mat_delete A i i))"
proof -
  let ?det = Determinant.det
  let ?m = "map_mat (λa. [:- a:])" 
  let ?lam = "[:0, 1:] m 1m n :: 'a poly mat" 
  from A have id: "dim_row A = n" by auto  

  (* Sum is the unfolded char_poly_matrix of A, split into lam and mA. *)
  define mA where "mA = ?m A"
  define lam where "lam = ?lam" 
  let ?sum = "lam + mA" 
  define Sum where "Sum = ?sum" 
  have mA: "mA  carrier_mat n n" and 
    lam: "lam  carrier_mat n n" and
    Sum: "Sum  carrier_mat n n" 
    using A unfolding mA_def Sum_def lam_def by auto
  let ?P = "{p. p permutes {0..<n}}" 
  let ?e = "λ p. (i = 0..<n. Sum $$ (i, p i))" 
  (* f: summand for permutation p with the factor at row a differentiated;
     g: the same summand with that factor left out. *)
  let ?f = "λ p a. signof p * (i{0..<n} - {a}. Sum $$ (i, p i)) * pderiv (Sum $$ (a, p a))" 
  let ?g = "λ p a. signof p * (i{0..<n} - {a}. Sum $$ (i, p i))" 
  define f where "f = ?f" 
  define g where "g = ?g" 
  {
    fix p
    assume p: "p permutes {0 ..< n}" 
    (* The sign is a constant polynomial, so its derivative vanishes. *)
    have "pderiv (signof p :: 'a poly) = 0" unfolding signof_def by (simp add: pderiv_minus) 
    hence "pderiv (signof p * ?e p) = signof p * pderiv (i = 0..<n. Sum $$ (i, p i))" 
      unfolding pderiv_mult by auto
    also have "signof p * pderiv (i = 0..<n. Sum $$ (i, p i)) = 
       (a = 0..<n. f p a)" 
      unfolding pderiv_prod sum_distrib_left f_def by (simp add: ac_simps)
    also note calculation
  } note to_f = this
  {
    fix a
    assume a: "a < n" 
    (* Split the permutations according to how they act on a. *)
    have Psplit: "?P = {p. p permutes {0..<n}  p a = a}  (?P - {p. p a = a})" (is "_ = ?Pa  ?Pz") by auto 
    {
      fix p
      assume p: "p permutes {0 ..< n}" "p a  a"
      hence "pderiv (Sum $$ (a, p a)) = 0" unfolding Sum_def lam_def mA_def using a p A by auto
      hence "f p a = 0" unfolding f_def by auto
    } note 0 = this
    {
      fix p
      assume p: "p permutes {0 ..< n}" "p a = a"
      hence "pderiv (Sum $$ (a, p a)) = 1" unfolding Sum_def lam_def mA_def using a p A
        by (auto simp: pderiv_pCons)
      hence "f p a = g p a" unfolding f_def g_def by auto
    } note fg = this
    let ?n = "n - 1" 
    from a have n: "Suc ?n = n" by simp
    (* ?B is the unfolded char_poly_matrix of mat_delete A a a. *)
    let ?B = "[:0, 1:] m 1m ?n + ?m (mat_delete A a a)" 
    have B: "?B  carrier_mat ?n ?n" using A by auto
    have "sum (λ p. f p a) ?P = sum (λ p. f p a) ?Pa + sum (λ p. f p a) ?Pz" 
      by (subst sum.union_disjoint[symmetric], auto simp: finite_permutations Psplit[symmetric])
    also have " = sum (λ p. f p a) ?Pa" 
      by (subst (2) sum.neutral, insert 0, auto)
    also have " = sum (λ p. g p a) ?Pa" 
      by (rule sum.cong, auto simp: fg)
    also have " = ?det ?B"
      unfolding det_def'[OF B] 
      unfolding permutation_fix[of a ?n a, unfolded n, OF a a]
      unfolding sum.reindex[OF permutation_insert_inj_on[of a ?n a, unfolded n, OF a a]] o_def
    proof (rule sum.cong[OF refl])
      (* Reindexed sum: p ranges over permutations of {0..<n - 1} and is
         lifted with permutation_insert a a. *)
      fix p
      let ?Q = "{p. p permutes {0..<?n}}" 
      assume "p  ?Q"      
      hence p: "p permutes {0 ..< ?n}" by auto
      let ?p = "permutation_insert a a p"
      let ?i = "insert_index a" 
      have sign: "signof ?p = signof p" 
        unfolding signof_permutation_insert[OF p, unfolded n, OF a a] by simp
      show "g (permutation_insert a a p) a = signof p * (i = 0..<?n. ?B $$ (i, p i))" 
        unfolding g_def sign
      proof (rule arg_cong[of _ _ "(*) (signof p)"])
        have "(i{0..<n} - {a}. Sum $$ (i, ?p i)) = 
           prod (($$) Sum) ((λx. (x, ?p x)) ` ({0..<n} - {a}))"
          unfolding prod.reindex[OF inj_on_convol_ident, of _ ?p] o_def ..
        also have " = ( ii  {(i', ?p i') |i'. i'  {0..<n} - {a}}. Sum $$ ii)" 
          by (rule prod.cong, auto)
        also have " = prod (($$) Sum) ((λ i. (?i i, ?i (p i))) ` {0 ..< ?n})" 
          unfolding Determinant.foo[of a ?n a, unfolded n, OF a a p]
          by (rule arg_cong[of _ _ "prod _"], auto) 
        also have " = prod (λ i. Sum $$ (?i i, ?i (p i))) {0 ..< ?n}"
        proof (subst prod.reindex, unfold o_def)
          show "inj_on (λi. (?i i, ?i (p i))) {0..<?n}" using insert_index_inj_on[of a]
            by (auto simp: inj_on_def)
        qed simp
        also have " = (i = 0..<?n. ?B $$ (i, p i))" 
        proof (rule prod.cong[OF refl], rename_tac i)
          (* Entry-wise: the shifted entry of Sum is the entry of ?B. *)
          fix j
          assume "j  {0 ..< ?n}"
          hence j: "j < ?n" by auto
          with p have pj: "p j < ?n" by auto
          from j pj have jj: "?i j < n" "?i (p j) < n" by (auto simp: insert_index_def)
          let ?jj = "(?i j, ?i (p j))" 
          note index_adj = mat_delete_index[of _ ?n, unfolded n, OF _ a a j pj]
          have "Sum $$ ?jj = lam $$ ?jj + mA $$ ?jj" unfolding Sum_def using jj A lam mA by auto
          also have " = ?B $$ (j, p j)" 
            unfolding index_adj[OF mA] index_adj[OF lam] using j pj A
            by (simp add: mA_def lam_def mat_delete_def)
          finally show "Sum $$ ?jj = ?B $$ (j, p j)" .
        qed
        finally 
        show "(i{0..<n} - {a}. Sum $$ (i, ?p i)) = (i = 0..<?n. ?B $$ (i, p i))" .
      qed
    qed
    also have " = char_poly (mat_delete A a a)" unfolding char_poly_def char_poly_matrix_def
      using A by simp
    also note calculation
  } note to_char_poly = this
  (* Main calculation: expand, differentiate summand-wise, swap sums, apply to_char_poly. *)
  have "pderiv (char_poly A) = pderiv (?det Sum)" 
    unfolding char_poly_def char_poly_matrix_def id lam_def mA_def Sum_def by auto
  also have " = sum (λ p. pderiv (signof p * ?e p)) ?P" unfolding det_def'[OF Sum]
    pderiv_sum by (rule sum.cong, auto)
  also have " = sum (λ p. (a = 0..<n. f p a)) ?P" 
    by (rule sum.cong[OF refl], subst to_f, auto)
  also have " = (a = 0..<n. sum (λ p. f p a) ?P)" 
    by (rule sum.swap) 
  also have " = (a <n. char_poly (mat_delete A a a))" 
    by (rule sum.cong, auto simp: to_char_poly)
  finally show ?thesis .
qed    

(* If the entries A $$ (j,i) of column i are 0 for j < n, then the
   characteristic polynomial of A is monom 1 1 times the characteristic
   polynomial of mat_delete A i i. Proved by Laplace expansion along column
   i, where only the summand for row i remains. *)
lemma char_poly_0_column: fixes A :: "'a :: idom mat" 
  assumes 0: " j. j < n  A $$ (j,i) = 0" 
  and A: "A  carrier_mat n n" 
  and i: "i < n"
shows "char_poly A = monom 1 1 * char_poly (mat_delete A i i)" 
proof -
  let ?n = "n - 1" 
  let ?A = "mat_delete A i i" 
  let ?sum = "[:0, 1:] m 1m n + map_mat (λa. [:- a:]) A" 
  define Sum where "Sum = ?sum" 
  (* ?f j is the j-th summand of the expansion along column i. *)
  let ?f = "λ j. Sum $$ (j, i) * cofactor Sum j i" 
  have Sum: "Sum  carrier_mat n n" using A unfolding Sum_def by auto
  from A have id: "dim_row A = n" by auto  
  have "char_poly A = (j<n. ?f j)" 
    unfolding char_poly_def[of A] char_poly_matrix_def 
    using laplace_expansion_column[OF Sum i] unfolding Sum_def using A by simp
  also have " = ?f i + sum ?f ({..<n} - {i})" 
    by (rule sum.remove[of _ i], insert i, auto)
  also have " = ?f i" 
  proof (subst sum.neutral, intro ballI)
    (* Summands for rows other than i vanish. *)
    fix j
    assume "j  {..<n} - {i}" 
    hence j: "j < n" and ji: "j  i" by auto
    show "?f j = 0" unfolding Sum_def using ji j i A 0[OF j] by simp
  qed simp
  also have "?f i = [:0, 1:] * (cofactor Sum i i)" 
    unfolding Sum_def using i A 0[OF i] by simp
  also have "cofactor Sum i i = det (mat_delete Sum i i)" 
    unfolding cofactor_def by simp
  also have " = char_poly ?A" 
    unfolding char_poly_def char_poly_matrix_def Sum_def
  proof (rule arg_cong[of _ _ det])
    show "mat_delete ?sum i i = [:0, 1:] m 1m (dim_row ?A) + map_mat (λa. [:- a:]) ?A"
      using i A by (auto simp: mat_delete_def)
  qed
  also have "[:0, 1:] = (monom 1 1 :: 'a poly)" by (rule x_as_monom)
  finally show ?thesis .
qed

(* mat_erase A i j keeps the dimensions of A and replaces selected entries
   by 0, depending on whether the row index equals i and the column index
   equals j; all other entries are copied from A.
   NOTE(review): the connective between the two index tests is not visible
   in this text; presumably a disjunction (whole row i and whole column j
   are zeroed) - confirm. *)
definition mat_erase :: "'a :: zero mat  nat  nat  'a mat" where
  "mat_erase A i j = Matrix.mat (dim_row A) (dim_col A) 
     (λ (i',j'). if i' = i  j' = j then 0 else A $$ (i',j'))"  

(* Carrier rule relating mat_erase A i j and A with respect to
   carrier_mat nr nc. *)
lemma mat_erase_carrier[simp]: "(mat_erase A i j)  carrier_mat nr nc  A  carrier_mat nr nc" 
  unfolding mat_erase_def carrier_mat_def by simp

(* Variant of pderiv_char_poly using mat_erase: monom 1 1 times the derivative
   of the characteristic polynomial is expressed via the characteristic
   polynomials of mat_erase A i i for i < n. Each summand is handled with
   char_poly_0_column. *)
lemma pderiv_char_poly_mat_erase: fixes A :: "'a :: idom mat" 
  assumes A: "A  carrier_mat n n" 
  shows "monom 1 1 * pderiv (char_poly A) = (i < n. char_poly (mat_erase A i i))"
proof -
  show ?thesis unfolding pderiv_char_poly[OF A] sum_distrib_left
  proof (rule sum.cong[OF refl])
    fix i
    assume "i  {..<n}" 
    hence i: "i < n" by simp
    have mA: "mat_erase A i i  carrier_mat n n" using A by simp
    show "monom 1 1 * char_poly (mat_delete A i i) = char_poly (mat_erase A i i)"
      by (subst char_poly_0_column[OF _ mA i], (insert i A, force simp: mat_erase_def),
      rule arg_cong[of _ _ "λ x. f * char_poly x" for f],
      auto simp: mat_delete_def mat_erase_def)
  qed
qed
    
end

Theory Jordan_Normal_Form

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Jordan Normal Form›

text ‹This theory defines Jordan normal forms (JNFs) in a sparse representation, i.e., 
  as block-diagonal matrices. We also provide a closed formula for powers of JNFs, 
  which allows us to estimate the growth rates of JNFs.›

theory Jordan_Normal_Form
imports 
  Matrix
  Char_Poly
  Polynomial_Interpolation.Missing_Unsorted
begin

(* jordan_block n a is the n x n matrix with a on the diagonal (i = j), 1
   directly above the diagonal (Suc i = j), and 0 everywhere else. *)
definition jordan_block :: "nat  'a :: {zero,one}  'a mat" where 
  "jordan_block n a = mat n n (λ (i,j). if i = j then a else if Suc i = j then 1 else 0)"

(* Simp rules: entry lookup for indices below n, and both dimensions. *)
lemma jordan_block_index[simp]: "i < n  j < n  
  jordan_block n a $$ (i,j) = (if i = j then a else if Suc i = j then 1 else 0)"
  "dim_row (jordan_block n k) = n"
  "dim_col (jordan_block n k) = n"
  unfolding jordan_block_def by auto

(* Carrier statement for jordan_block n k and carrier_mat n n. *)
lemma jordan_block_carrier[simp]: "jordan_block n k  carrier_mat n n" 
  unfolding carrier_mat_def by auto

(* The characteristic polynomial of jordan_block n a is the n-th power of the
   linear polynomial [: -a, 1:]; uses det_upper_triangular. *)
lemma jordan_block_char_poly: "char_poly (jordan_block n a) = [: -a, 1:]^n"
  unfolding char_poly_defs by (subst det_upper_triangular[of _ n], auto simp: prod_list_diag_prod)

(* Simp rules for matrix powers of a Jordan block: a carrier statement with
   respect to carrier_mat n n, and both dimensions equal to n. *)
lemma jordan_block_pow_carrier[simp]:
  "jordan_block n a ^m r  carrier_mat n n" by auto
lemma jordan_block_pow_dim[simp]:
  "dim_row (jordan_block n a ^m r) = n" "dim_col (jordan_block n a ^m r) = n" by auto

(* Closed form for the r-th power of a Jordan block over a commutative ring:
   for i ≤ j the entry is (r choose (j - i)) * a ^ (r + i - j), below the
   diagonal it is 0.  Proved by induction on r; the step multiplies the
   closed form for r with one more Jordan block and uses the Pascal
   recurrence (binomial_Suc_Suc). *)
lemma jordan_block_pow: "(jordan_block n (a :: 'a :: comm_ring_1)) ^⇩m r = 
  mat n n (λ (i,j). if i ≤ j then of_nat (r choose (j - i)) * a ^ (r + i - j) else 0)"
proof (induct r)
  case 0
  {
    fix i j :: nat
    assume "i ≠ j" "i ≤ j"
    hence "j - i > 0" by auto
    hence "0 choose (j - i) = 0" by simp
  } note [simp] = this
  show ?case
    by (simp, rule eq_matI, auto)
next
  case (Suc r)
  let ?jb = "jordan_block n a"
  let ?rij = "λ r i j. of_nat (r choose (j - i)) * a ^ (r + i - j)"
  let ?v = "λ i j. if i ≤ j then of_nat (r choose (j - i)) * a ^ (r + i - j) else 0"
  have "?jb ^⇩m Suc r = mat n n (λ (i,j). if i ≤ j then ?rij r i j else 0) * ?jb" by (simp add: Suc)
  also have "… = mat n n (λ (i,j). if i ≤ j then ?rij (Suc r) i j else 0)"
  proof -
    (* Scalar product of an arbitrary row vector with column j of the block:
       only positions j and j - 1 contribute. *)
    {
      fix j
      assume j: "j < n"
      hence col: "col (jordan_block n a) j = vec n (λi. if i = j then a else if Suc i = j then 1 else 0)"
        unfolding jordan_block_def col_mat[OF j] by simp
      fix f
      have "vec n f ∙ col (jordan_block n a) j = (f j * a + (if j = 0 then 0 else f (j - 1)))"
      proof -
        define p where "p = (λ i. vec n f $ i * col (jordan_block n a) j $ i)"
        have "vec n f ∙ col (jordan_block n a) j = (∑i = 0 ..< n. p i)"
          unfolding scalar_prod_def p_def by simp
        also have "… = p j + sum p ({0 ..< n} - {j})" using j
          by (subst sum.remove[of _ j], auto)
        also have "p j = f j * a" unfolding p_def col using j by auto
        also have "sum p ({0 ..< n} - {j}) = (if j = 0 then 0 else f (j - 1))"
        proof (cases j)
          case 0
          have "sum p ({0 ..< n} - {j}) = 0"
            by (rule sum.neutral, auto simp: p_def col 0)
          thus ?thesis using 0 by simp
        next
          case (Suc jj)
          with j have jj: "jj ∈ {0 ..< n} - {j}" by auto
          have "sum p ({0 ..< n} - {j}) = p jj + sum p ({0 ..< n} - {j} - {jj})"
            by (subst sum.remove[OF _ jj], auto)
          also have "p jj = f (j - 1)" unfolding p_def col using jj
            by (auto simp: Suc)
          also have "sum p ({0 ..< n} - {j} - {jj}) = 0"
            by (rule sum.neutral, auto simp: p_def col, auto simp: Suc)
          finally show ?thesis unfolding Suc by simp
        qed
        finally show ?thesis .
      qed
    } note scalar_to_sum = this
    (* Entries strictly below the diagonal stay 0. *)
    {
      fix i j
      assume i: "i < n" and ij: "i > j"
      hence j: "j < n" by auto
      have "vec n (?v i) ∙ col (jordan_block n a) j = 0"
        unfolding scalar_to_sum[OF j] using ij i j by auto
    } note easy_case = this
    (* Entries on or above the diagonal follow the Pascal recurrence. *)
    {
      fix i j
      assume j: "j < n" and ij: "i ≤ j"
      hence i: "i < n" and id: "⋀ p q. (if i ≤ j then p else q) = p" by auto
      have "vec n (?v i) ∙ col (jordan_block n a) j =
        (of_nat (r choose (j - i)) * (a ^ (Suc (r + i - j)))) +
          (if j = 0 then 0
         else if i ≤ j - 1 then of_nat (r choose (j - 1 - i)) * a ^ (r + i - (j - 1)) else 0)"
      unfolding scalar_to_sum[OF j]
      using ij by simp
      also have "… = of_nat (Suc r choose (j - i)) * a ^ (Suc (r + i) - j)"
      proof (cases j)
        case (Suc jj)
        {
          assume "i ≤ Suc jj" and "¬ i ≤ jj"
          hence "i = Suc jj" by auto 
          hence "a * a ^ (r + i - Suc jj) = a ^ (r + i - jj)" by simp
        } 
        moreover
        {
          assume ijj: "i ≤ jj"
          have "of_nat (r choose (Suc jj - i)) * (a * a ^ (r + i - Suc jj)) 
          + of_nat (r choose (jj - i)) * a ^ (r + i - jj) =
            of_nat (Suc r choose (Suc jj - i)) * a ^ (r + i - jj)"
          proof (cases "r + i < jj")
            case True
            hence gt: "jj - i > r" "Suc jj - i > r" "Suc jj - i > Suc r" by auto
            show ?thesis 
              unfolding binomial_eq_0[OF gt(1)] binomial_eq_0[OF gt(2)] binomial_eq_0[OF gt(3)]
              by simp
          next
            case False 
            hence ge: "r + i ≥ jj" by simp
            show ?thesis
            proof (cases "jj = r + i")
              case True
              have gt: "r < Suc r" by simp
              show ?thesis unfolding True by (simp add: binomial_eq_0[OF gt])
            next
              case False
              with ge have lt: "jj < r + i" by auto
              hence "r + i - jj = Suc (r + i - Suc jj)" by simp 
              hence prod: "a * a ^ (r + i - Suc jj) = a ^ (r + i - jj)" by simp
              from ijj have id: "Suc jj - i = Suc (jj - i)" by simp
              have binom: "Suc r choose (Suc jj - i) = 
                r choose (Suc jj - i) + (r choose (jj - i))"
                unfolding id
                by (subst binomial_Suc_Suc, simp)
              show ?thesis unfolding prod binom  
                by (simp add: field_simps)
            qed
          qed
        }
        ultimately show ?thesis using ij unfolding Suc by auto
      qed auto
      finally have "vec n (?v i) ∙ col (jordan_block n a) j 
        = of_nat (Suc r choose (j - i)) * a ^ (Suc (r + i) - j)" .
    } note main_case = this
    show ?thesis
      by (rule eq_matI, insert easy_case main_case, auto)
  qed
  finally show ?case by simp
qed

(* A Jordan matrix is the block-diagonal matrix built from the Jordan blocks
   described by the list of (size, diagonal value) pairs. *)
definition jordan_matrix :: "(nat × 'a :: {zero,one})list ⇒ 'a mat" where
  "jordan_matrix n_as = diag_block_mat (map (λ (n,a). jordan_block n a) n_as)"

(* Both dimensions of a Jordan matrix equal the sum of its block sizes. *)
lemma jordan_matrix_dim[simp]: 
  "dim_row (jordan_matrix n_as) = sum_list (map fst n_as)"
  "dim_col (jordan_matrix n_as) = sum_list (map fst n_as)"
  unfolding jordan_matrix_def
  by (subst dim_diag_block_mat, auto, (induct n_as, auto simp: Let_def)+)

(* A Jordan matrix is square, with side length the sum of the block sizes. *)
lemma jordan_matrix_carrier[simp]: 
  "jordan_matrix n_as ∈ carrier_mat (sum_list (map fst n_as)) (sum_list (map fst n_as))"
  unfolding carrier_mat_def by auto

(* Jordan matrices are upper triangular: every entry strictly below the
   diagonal is 0. *)
lemma jordan_matrix_upper_triangular: "i < sum_list (map fst n_as)
  ⟹ j < i ⟹ jordan_matrix n_as $$ (i,j) = 0"
  unfolding jordan_matrix_def
  by (rule diag_block_upper_triangular, auto simp: jordan_matrix_def[symmetric])

(* Powers of a Jordan matrix are computed blockwise. *)
lemma jordan_matrix_pow: "(jordan_matrix n_as) ^⇩m r = 
  diag_block_mat (map (λ (n,a). (jordan_block n a) ^⇩m r) n_as)"
  unfolding jordan_matrix_def
  by (subst diag_block_pow_mat, force, rule arg_cong[of _ _ diag_block_mat], auto)

(* The characteristic polynomial of a Jordan matrix is the product of
   (x - a)^n over its blocks (n,a).  The proof first computes the diagonal of
   the characteristic matrix by induction on the block list and then applies
   the determinant rule for upper triangular matrices. *)
lemma jordan_matrix_char_poly: 
  "char_poly (jordan_matrix n_as) = (∏(n, a)←n_as. [:- a, 1:] ^ n)"
proof -
  let ?n = "sum_list (map fst n_as)"
  have "diag_mat
     ([:0, 1:] ⋅⇩m 1⇩m (sum_list (map fst n_as)) + map_mat (λa. [:- a:]) (jordan_matrix n_as)) =
    concat (map (λ(n, a). replicate n [:- a, 1:]) n_as)" unfolding jordan_matrix_def
  proof (induct n_as)
    case (Cons na n_as)
    obtain n a where na: "na = (n,a)" by force
    let ?n2 = "sum_list (map fst n_as)"
    note fbo = four_block_one_mat
    note mz = zero_carrier_mat
    note mo = one_carrier_mat
    have mA: "⋀ A. A ∈ carrier_mat (dim_row A) (dim_col A)" unfolding carrier_mat_def by auto
    let ?Bs = "map (λ(x, y). jordan_block x y) n_as"
    let ?B = "diag_block_mat ?Bs"
    from jordan_matrix_dim[of n_as, unfolded jordan_matrix_def]
    have dimB: "dim_row ?B = ?n2" "dim_col ?B = ?n2" by auto
    hence B: "?B ∈ carrier_mat ?n2 ?n2" unfolding carrier_mat_def by simp
    show ?case unfolding na fbo
    apply (simp add: Let_def fbo[symmetric] del: fbo)
    apply (subst smult_four_block_mat[OF mo mz mz mo])
    apply (subst map_four_block_mat[OF jordan_block_carrier mz mz mA])
    apply (subst add_four_block_mat[of _ n n _ ?n2 _ ?n2], auto simp: dimB B)
    apply (subst diag_four_block_mat[of _ n _ ?n2], auto simp: dimB B)
    apply (subst Cons, auto simp: jordan_block_def diag_mat_def, 
      intro nth_equalityI, auto)
    done
  qed (force simp: diag_mat_def)
  also have "prod_list ... = (∏(n, a)←n_as. [:- a, 1:] ^ n)"
    by (induct n_as, auto)
  finally
  show ?thesis unfolding char_poly_defs
    by (subst det_upper_triangular[of _ ?n], auto simp: jordan_matrix_upper_triangular)
qed

(* n_as is a Jordan normal form of A: no block has size 0 and A is similar
   to the Jordan matrix described by n_as. *)
definition jordan_nf :: "'a :: semiring_1 mat ⇒ (nat × 'a)list ⇒ bool" where
  "jordan_nf A n_as ≡ (0 ∉ fst ` set n_as ∧ similar_mat A (jordan_matrix n_as))"

(* Elimination rule for a Jordan normal form: yields transformation matrices
   P and Q, the factorisation of the characteristic polynomial, and the
   representation of every power A^k as P * J^k * Q. *)
lemma jordan_nf_powE: assumes A: "A ∈ carrier_mat n n" and jnf: "jordan_nf A n_as" 
  obtains P Q where "P ∈ carrier_mat n n" "Q ∈ carrier_mat n n" and 
  "char_poly A = (∏(na, a)←n_as. [:- a, 1:] ^ na)"
  "⋀ k. A ^⇩m k = P * (jordan_matrix n_as)^⇩m k * Q"
proof -
  from A have dim: "dim_row A = n" by auto
  assume obt: "⋀P Q. P ∈ carrier_mat n n ⟹ Q ∈ carrier_mat n n ⟹ 
    char_poly A = (∏(na, a)←n_as. [:- a, 1:] ^ na) ⟹ 
    (⋀k. A ^⇩m k = P * jordan_matrix n_as ^⇩m k * Q) ⟹ thesis"
  from jnf[unfolded jordan_nf_def] obtain P Q where
    simw: "similar_mat_wit A (jordan_matrix n_as) P Q"
    and sim: "similar_mat A (jordan_matrix n_as)" unfolding similar_mat_def by blast
  show thesis
  proof (rule obt)
    show "⋀ k. A ^⇩m k = P * jordan_matrix n_as ^⇩m k * Q"
      by (rule similar_mat_wit_pow_id[OF simw])
    show "char_poly A = (∏(na, a)←n_as. [:- a, 1:] ^ na)"
      unfolding char_poly_similar[OF sim] jordan_matrix_char_poly ..    
  qed (insert simw[unfolded similar_mat_wit_def Let_def dim], auto)
qed

(* Polynomial bound on binomial coefficients: for i ≤ d, (r choose i) is at
   most max 1 (r^d).  The max with 1 covers r = 0. *)
lemma choose_poly_bound: assumes "i ≤ d"
  shows "r choose i ≤ max 1 (r^d)"
proof (cases "i ≤ r")
  case False
  hence "r choose i = 0" by simp
  thus ?thesis by arith
next
  case True
  show ?thesis
  proof (cases r)
    case (Suc rr)
    from binomial_le_pow[OF True] have "r choose i ≤ r ^ i" by simp
    also have "… ≤ r^d" using power_increasing[OF ‹i ≤ d›, of r] Suc by auto
    finally show ?thesis by simp
  qed (insert True, simp)
qed  

context
  fixes b :: "'a :: archimedean_field"
  assumes b: "0 < b" "b < 1"
begin
      
(* For the context's b with 0 < b < 1: c * b^x * x^deg is bounded by a
   constant, uniformly in x.  Non-positive c is handled directly with bound 0,
   positive c by scaling the bound from poly_exp_bound. *)
lemma poly_exp_constant_bound: "∃ p. ∀ x. c * b ^ x * of_nat x ^ deg ≤ p" 
proof (cases "c ≤ 0")
  case True
  show ?thesis
    by (rule exI[of _ 0], intro allI, 
    rule mult_nonpos_nonneg[OF mult_nonpos_nonneg[OF True]], insert b, auto)
next
  case False
  hence c: "c ≥ 0" by simp
  from poly_exp_bound[OF b, of deg] obtain p where "⋀ x. b ^ x * of_nat x ^ deg ≤ p" by auto
  from mult_left_mono[OF this c]
  show ?thesis by (intro exI[of _ "c * p"], auto simp: ac_simps)
qed

(* Variant of poly_exp_constant_bound with max 1 (x^deg) in place of x^deg;
   the witness is max p c, where c accounts for x = 0. *)
lemma poly_exp_max_constant_bound: "∃ p. ∀ x. c * b ^ x * max 1 (of_nat x ^ deg) ≤ p" 
proof -
  from poly_exp_constant_bound[of c deg] obtain p where
    p: "⋀ x. c * b ^ x * of_nat x ^ deg ≤ p" by auto
  show ?thesis
  proof (rule exI[of _ "max p c"], intro allI)
    fix x
    let ?exp = "of_nat x ^ deg :: 'a"
    show "c * b ^ x * max 1 ?exp ≤ max p c"
    proof (cases "x = 0")
      case False
      hence "?exp ≠ of_nat 0" by simp
      hence "?exp ≥ 1" by (metis less_one not_less of_nat_1 of_nat_less_iff of_nat_power)
      hence "max 1 ?exp = ?exp" by simp
      thus ?thesis using p[of x] by simp
    qed (cases deg, auto)
  qed
qed
end

context
  fixes a :: "'a :: real_normed_field"
begin
(* Norm bound for the entries of the k-th power of a Jordan block:
   norm a ^ (k + i - j) times the polynomial factor max 1 (k ^ (n - 1)).
   Combines the closed form jordan_block_pow with choose_poly_bound. *)
lemma jordan_block_bound: 
  assumes i: "i < n" and j: "j < n"
  shows "norm ((jordan_block n a ^⇩m k) $$ (i,j)) 
    ≤ norm a ^ (k + i - j) * max 1 (of_nat k ^ (n - 1))"
    (is "?lhs ≤ ?rhs")
proof -
  have id: "(jordan_block n a ^⇩m k) $$ (i,j) = (if i ≤ j then of_nat (k choose (j - i)) * a ^ (k + i - j) else 0)"
    unfolding jordan_block_pow using i j by auto
  from i j have diff: "j - i ≤ n - 1" by auto
  show ?thesis
  proof (cases "i ≤ j")
    case False
    thus ?thesis unfolding id by simp
  next
    case True
    hence "?lhs = norm (of_nat (k choose (j - i)) * a ^ (k + i - j))" unfolding id by simp
    also have "… ≤ norm (of_nat (k choose (j - i)) :: 'a) * norm (a ^ (k + i - j))"
      by (rule norm_mult_ineq)
    also have "… ≤ (max 1 (of_nat k ^ (n - 1))) * norm a ^ (k + i - j)"
    proof (rule mult_mono[OF _ norm_power_ineq _ norm_ge_zero])
      have "k choose (j - i) ≤ max 1 (k ^ (n - 1))" 
        by (rule choose_poly_bound[OF diff])
      hence "norm (of_nat (k choose (j - i)) :: 'a) ≤ of_nat (max 1 (k ^ (n - 1)))"
        unfolding norm_of_nat of_nat_le_iff .
      also have "… = max 1 (of_nat k ^ (n - 1))" by (metis max_def of_nat_1 of_nat_le_iff of_nat_power)
      finally show "norm (of_nat (k choose (j - i)) :: 'a) ≤ max 1 (real_of_nat k ^ (n - 1))" .
    qed simp
    also have "… = ?rhs" by simp
    finally show ?thesis .
  qed
qed

(* Special case of jordan_block_bound for norm a = 1: the entries of the
   k-th power are bounded by max 1 (k ^ (n - 1)). *)
lemma jordan_block_poly_bound: 
  assumes i: "i < n" and j: "j < n" and a: "norm a = 1"
  shows "norm ((jordan_block n a ^⇩m k) $$ (i,j)) ≤ max 1 (of_nat k ^ (n - 1))"
    (is "?lhs ≤ ?rhs")
proof -
  from jordan_block_bound[OF i j, of k, unfolded a]
  show ?thesis by simp
qed


(* For norm a < 1 the entries of all powers of a Jordan block are bounded by
   one constant p, independent of i, j and k.  The case a = 0 uses the bound
   1 directly; otherwise the bound comes from poly_exp_max_constant_bound
   with b = norm a and c = inverse (norm a ^ n). *)
theorem jordan_block_constant_bound: assumes a: "norm a < 1" 
  shows "∃ p. ∀ i j k. i < n ⟶ j < n ⟶ norm ((jordan_block n a ^⇩m k) $$ (i,j)) ≤ p"
proof (cases "a = 0") 
  case True
  show ?thesis
  proof (rule exI[of _ 1], intro allI impI)
    fix i j k
    assume *: "i < n" "j < n"
    {
      assume ij: "i ≤ j"
      have "norm ((of_nat (k choose (j - i)) :: 'a) * 0 ^ (k + i - j)) ≤ 1" (is "norm ?lhs ≤ 1")
      proof (cases "k + i > j")
        case True
        hence "?lhs = 0" by simp
        also have "norm (…) ≤ 1" by simp
        finally show ?thesis .
      next
        case False
        hence id: "?lhs = (of_nat (k choose (j - i)) :: 'a)" and j: "j - i ≥ k" by auto
        from j have "k choose (j - i) = 0 ∨ k choose (j - i) = 1" by (simp add: nat_less_le)
        thus "norm ?lhs ≤ 1"
        proof
          assume k: "k choose (j - i) = 0"
          show ?thesis unfolding id k by simp
        next
          assume k: "k choose (j - i) = 1"
          show ?thesis unfolding id unfolding k by simp
        qed
      qed
    }    
    thus "norm ((jordan_block n a ^⇩m k) $$ (i,j)) ≤ 1" unfolding True
      unfolding jordan_block_pow using * by auto
  qed
next
  case False
  hence na: "norm a > 0" by auto
  define c where "c = inverse (norm a ^ n)"
  define deg where "deg = n - 1"
  have c: "c > 0" unfolding c_def using na by auto
  define b where "b = norm a"
  from a na have "0 < b" "b < 1" unfolding b_def by auto
  from poly_exp_max_constant_bound[OF this, of c deg]
  obtain p where "⋀ k. c * b ^ k * max 1 (of_nat k ^ deg) ≤ p" by auto
  show ?thesis
  proof (intro exI[of _ p], intro allI impI)
    fix i j k
    assume ij: "i < n" "j < n"
    from jordan_block_bound[OF this]
    have "norm ((jordan_block n a ^⇩m k) $$ (i, j))
      ≤ norm a ^ (k + i - j) * max 1 (real_of_nat k ^ (n - 1))" .
    also have "… ≤ c * norm a ^ k * max 1 (real_of_nat k ^ (n - 1))"
    proof (rule mult_right_mono)
      from ij have "i - j ≤ n" by auto
      show "norm a ^ (k + i - j) ≤ c * norm a ^ k"
      proof (rule mult_left_le_imp_le)
        show "0 < norm a ^ n" using na by auto
        let ?lhs = "norm a ^ n * norm a ^ (k + i - j)"
        let ?rhs = "norm a ^ n * (c * norm a ^ k)"
        from ij have ge: "n + (k + i - j) ≥ k" by arith
        have "?lhs = norm a ^ (n + (k + i - j))" by (simp add: power_add)
        also have "… ≤ norm a ^ k" using ge a na using less_imp_le power_decreasing by blast
        also have "… = ?rhs" unfolding c_def using na by simp
        finally show "?lhs ≤ ?rhs" .
      qed
    qed simp
    also have "… = c * b ^ k * max 1 (real_of_nat k ^ deg)" unfolding b_def deg_def ..
    also have "… ≤ p" by fact
    finally show "norm ((jordan_block n a ^⇩m k) $$ (i, j)) ≤ p" .
  qed
qed

(* norm_bound A b: every entry of A has norm at most b. *)
definition norm_bound :: "'a mat ⇒ real ⇒ bool" where
  "norm_bound A b ≡ ∀ i j. i < dim_row A ⟶ j < dim_col A ⟶ norm (A $$ (i,j)) ≤ b"

(* Introduction rule for norm_bound in meta-logic form. *)
lemma norm_boundI[intro]:
  assumes "⋀ i j. i < dim_row A ⟹ j < dim_col A ⟹ norm (A $$ (i,j)) ≤ b"
  shows "norm_bound A b"
  unfolding norm_bound_def using assms by blast

(* Restatement of jordan_block_constant_bound with the existential quantifier
   moved outside the premise norm a < 1. *)
lemma  jordan_block_constant_bound2:
"∃p. norm (a :: 'a :: real_normed_field) < 1 ⟶
    (∀i j k. i < n ⟶ j < n ⟶ norm ((jordan_block n a ^⇩m k) $$ (i, j)) ≤ p)"
using jordan_block_constant_bound by auto

(* Polynomial bound on all elements of the powers of a Jordan matrix whose
   blocks have norm a ≤ 1 and whose blocks with norm a = 1 have size at most
   N: every element of the k-th power has norm at most c1 + k ^ (N - 1).
   Outline: bound each block separately (polynomially for norm a = 1, by a
   constant for norm a < 1), pick per-block constants via choice, take their
   maximum, and lift to the block-diagonal matrix. *)
lemma jordan_matrix_poly_bound2:
  fixes n_as :: "(nat × 'a) list"
  assumes n_as: "⋀ n a. (n,a) ∈ set n_as ⟹ n > 0 ⟹ norm a ≤ 1"
  and N: "⋀ n a. (n,a) ∈ set n_as ⟹ norm a = 1 ⟹ n ≤ N"
  shows "∃c1. ∀k. ∀e ∈ elements_mat (jordan_matrix n_as ^⇩m k).
    norm e ≤ c1 + of_nat k ^ (N - 1)"
proof -
  from jordan_matrix_carrier[of n_as] obtain d where
    jm: "jordan_matrix n_as ∈ carrier_mat d d" by blast
  define f where "f = (λn (a::'a) i j k. norm ((jordan_block n a ^⇩m k) $$ (i,j)))"
  let ?g = "λk c1. c1 + of_nat k ^ (N-1)"
  let ?P = "λn (a::'a) i j k c1. f n a i j k ≤ ?g k c1"
  define Q where "Q = (λn (a::'a) k c1. ∀i j. i<n ⟶ j<n ⟶ ?P n a i j k c1)"
  have "⋀ c c' k n a i j. c ≤ c' ⟹ ?P n a i j k c ⟹ ?P n a i j k c'" by auto  
  hence Q_mono: "⋀n a c c'. c ≤ c' ⟹ ∀k. Q n a k c ⟹ ∀k. Q n a k c'"
    unfolding Q_def by arith
  (* Per-block bound. *)
  { fix n a assume na: "(n,a) ∈ set n_as"
    obtain c where c: "norm a < 1 ⟶ (∀i j k. i < n ⟶ j < n ⟶ f n a i j k ≤ c)"
      apply (rule exE[OF jordan_block_constant_bound2])
      unfolding f_def using Jordan_Normal_Form.jordan_block_constant_bound2
      by metis
    define c1 where "c1 = max 1 c"
    then have "c1 ≥ 1" "c1 ≥ c" by auto
    have "∃c1. ∀k i j. i < n ⟶ j < n ⟶ ?P n a i j k c1"
    proof rule+
      fix i j k assume "i < n" "j < n"
      then have "0<n" by auto
      let ?jbs = "map (λ(n,a). jordan_block n a) n_as"
      have sq_jbs: "Ball (set ?jbs) square_mat" by auto
      have "jordan_matrix n_as ^⇩m k = diag_block_mat (map (λA. A ^⇩m k) ?jbs)"
        unfolding jordan_matrix_def using diag_block_pow_mat[OF sq_jbs] by auto
      show "?P n a i j k c1"
      proof (cases "norm a = 1")
        case True {
          have nN:"n-1 ≤ N-1" using N[OF na] True by auto
          have "f n a i j k ≤ max 1 (of_nat k ^ (n-1))"
            using Jordan_Normal_Form.jordan_block_poly_bound True ‹i<n› ‹j<n›
            unfolding f_def by auto
          also have "... ≤ max 1 (of_nat k ^ (N-1))"
            proof (cases "k=0")
              case False then show ?thesis
                by (subst max.mono[OF _ power_increasing[OF nN]], auto)
            qed (simp add: power_eq_if)
          also have "... ≤ max c1 (of_nat k ^ (N-1))" using ‹c1≥1› by auto
          also have "... ≤ c1 + (of_nat k ^ (N-1))" using ‹c1≥1› by auto
          finally show ?thesis by simp
        } next
        case False {
          then have na1: "norm a < 1" using n_as[OF na] ‹0<n› by auto
          hence "f n a i j k ≤ c" using c ‹i<n› ‹j<n› by auto
          also have "... ≤ c1" using ‹c≤c1›.
          also have "... ≤ c1 + of_nat k ^ (N-1)" by auto
          finally show ?thesis by auto
        }
      qed
    qed
  }
  hence "∀na. ∃c1. na ∈ set n_as ⟶ (∀k. Q (fst na) (snd na) k c1)"
    unfolding Q_def by auto
  from choice[OF this] obtain c'
    where c': "⋀ na k. na ∈ set n_as ⟹ Q (fst na) (snd na) k (c' na)" by blast
  (* One constant for all blocks: the maximum of the per-block constants,
     kept non-negative so that zero entries are covered as well. *)
  define c where "c = max 0 (Max (set (map c' n_as)))"
  { fix n a assume na: "(n,a) ∈ set n_as"
    then have Q: "∀ k. Q n a k (c' (n,a))" using c'[OF na] by auto
    from na have "c' (n,a) ∈ set (map c' n_as)" by auto
    from Max_ge[OF _ this] have "c' (n,a) ≤ c" unfolding c_def by auto
    from Q_mono[OF this Q] have "∀ k. Q n a k c" by blast
  }
  hence Q: "⋀k n a. (n,a) ∈ set n_as ⟹ Q n a k c" by auto
  have c0: "c ≥ 0" unfolding c_def by simp
  { fix k n a e
    assume na:"(n,a) ∈ set n_as"
    let ?jbk = "jordan_block n a ^⇩m k"
    assume "e ∈ elements_mat ?jbk"
    from elements_matD[OF this] obtain i j
      where "i < n" "j < n" and [simp]: "e = ?jbk $$ (i,j)"
      by (simp only:pow_mat_dim_square[OF jordan_block_carrier],auto)
    hence "norm e ≤ ?g k c" using Q[OF na] unfolding Q_def f_def by simp
  }
  hence norm_jordan:
    "⋀k. ∀(n,a) ∈ set n_as. ∀e ∈ elements_mat (jordan_block n a ^⇩m k).
     norm e ≤ ?g k c" by auto
  (* Lift the block bound to the whole block-diagonal matrix. *)
  { fix k
    let ?jmk = "jordan_matrix n_as ^⇩m k"
    have "dim_row ?jmk = d" "dim_col ?jmk = d"
      using jm by (simp only:pow_mat_dim_square[OF jm])+
    let ?As = "(map (λ(n,a). jordan_block n a ^⇩m k) n_as)"
    have "⋀e. e ∈ elements_mat ?jmk ⟹ norm e ≤ ?g k c"
    proof -
      fix e assume e:"e ∈ elements_mat ?jmk"
      obtain i j where ij: "i < d" "j < d" and "e = ?jmk $$ (i,j)"
        using elements_matD[OF e] by (simp only:pow_mat_dim_square[OF jm],auto)
      have "?jmk = diag_block_mat ?As"
        using jordan_matrix_pow[of n_as k] by auto
      hence "elements_mat ?jmk ⊆ {0} ∪ ⋃ (set (map elements_mat ?As))"
        using elements_diag_block_mat[of ?As] by auto
      hence e_mem: "e ∈ {0} ∪ ⋃ (set (map elements_mat ?As))"
        using e by blast
      show "norm e ≤ ?g k c"
      proof (cases "e = 0")
        case False
          then have "e ∈ ⋃ (set (map elements_mat ?As))" using e_mem by auto
          then obtain n a
            where "e ∈ elements_mat (jordan_block n a ^⇩m k)"
            and na: "(n,a) ∈ set n_as" by force
          thus ?thesis using norm_jordan na by force
      qed (insert c0, auto)
    qed
  }
  thus ?thesis by auto
qed

(* A bound on all elements of a matrix yields norm_bound. *)
lemma norm_bound_bridge:
  "∀e ∈ elements_mat A. norm e ≤ b ⟹ norm_bound A b"
  unfolding norm_bound_def by force

(* Norm bounds multiply: if A1 (nr-by-n) is bounded by b1 and A2 (n-by-nc)
   by b2, then A1 * A2 is bounded by b1 * b2 * n, since each entry is a sum
   of n products. *)
lemma norm_bound_mult: assumes A1: "A1 ∈ carrier_mat nr n"
  and A2: "A2 ∈ carrier_mat n nc"
  and b1: "norm_bound A1 b1"
  and b2: "norm_bound A2 b2"
  shows "norm_bound (A1 * A2) (b1 * b2 * of_nat n)"
proof 
  let ?A = "A1 * A2"
  let ?n = "of_nat n"
  fix i j
  assume i: "i < dim_row ?A" and j: "j < dim_col ?A"
  define v1 where "v1 = (λ k. row A1 i $ k)"
  define v2 where "v2 = (λ k. col A2 j $ k)"
  from assms(1-2) have dim: "dim_row A1 = nr" "dim_col A2 = nc" "dim_col A1 = n" "dim_row A2 = n" by auto
  {
    fix k
    assume k: "k < n"
    have n: "norm (v1 k) ≤ b1" "norm (v2 k) ≤ b2" 
      using i j k dim v1_def v2_def
      b1[unfolded norm_bound_def, rule_format, of i k] 
      b2[unfolded norm_bound_def, rule_format, of k j] by auto
    have "norm (v1 k * v2 k) ≤ norm (v1 k) * norm (v2 k)" by (rule norm_mult_ineq)
    also have "… ≤ b1 * b2" by (rule mult_mono'[OF n], auto)
    finally have "norm (v1 k * v2 k) ≤ b1 * b2" .
  } note bound = this
  have "?A $$ (i,j) = row A1 i ∙ col A2 j" using dim i j by simp
  also have "… = (∑ k = 0 ..< n. v1 k * v2 k)" unfolding scalar_prod_def 
    using dim i j v1_def v2_def by simp
  also have "norm (…) ≤ (∑ k = 0 ..< n. b1 * b2)" 
    by (rule sum_norm_le, insert bound, simp)
  also have "… = b1 * b2 * ?n" by simp
  finally show "norm (?A $$ (i,j)) ≤ b1 * b2 * ?n" .
qed

(* Every matrix is norm-bounded by the maximum of the norms of its entries
   (a finite set). *)
lemma norm_bound_max: "norm_bound A (Max {norm (A $$ (i,j)) | i j. i < dim_row A ∧ j < dim_col A})" 
  (is "norm_bound A (Max ?norms)")
proof 
  fix i j
  have fin: "finite ?norms" by (simp add: finite_image_set2)
  assume "i < dim_row A" and "j < dim_col A"     
  hence "norm (A $$ (i,j)) ∈ ?norms" by auto
  from Max_ge[OF fin this] show "norm (A $$ (i,j)) ≤ Max ?norms" .
qed

(* norm_bound form of jordan_matrix_poly_bound2: the k-th power of the Jordan
   matrix is bounded by c1 + k ^ (N - 1). *)
lemma jordan_matrix_poly_bound: fixes n_as :: "(nat × 'a)list"
  assumes n_as: "⋀ n a. (n,a) ∈ set n_as ⟹ n > 0 ⟹ norm a ≤ 1"
  and N: "⋀ n a. (n,a) ∈ set n_as ⟹ norm a = 1 ⟹ n ≤ N"
  shows "∃ c1. ∀ k. norm_bound (jordan_matrix n_as ^⇩m k) (c1 + of_nat k ^ (N - 1))" 
  using jordan_matrix_poly_bound2 norm_bound_bridge N n_as
  by metis

(* Polynomial growth bound for the powers of a matrix A with a Jordan normal
   form n_as: A^k is bounded by c1 + c2 * k ^ (N - 1).  Uses A^k = P * J^k * Q
   together with norm_bound_mult and the bound on J^k. *)
lemma jordan_nf_matrix_poly_bound: fixes n_as :: "(nat × 'a)list"
  assumes A: "A ∈ carrier_mat n n"
  and n_as: "⋀ n a. (n,a) ∈ set n_as ⟹ n > 0 ⟹ norm a ≤ 1"
  and N: "⋀ n a. (n,a) ∈ set n_as ⟹ norm a = 1 ⟹ n ≤ N"
  and jnf: "jordan_nf A n_as"
  shows "∃ c1 c2. ∀ k. norm_bound (A ^⇩m k) (c1 + c2 * of_nat k ^ (N - 1))"
proof -
  let ?cp2 = "∏(n, a)←n_as. [:- a, 1:] ^ n"
  let ?J = "jordan_matrix n_as"
  from jnf[unfolded jordan_nf_def]
  have sim: "similar_mat A ?J" by auto
  then obtain P Q where sim_wit: "similar_mat_wit A ?J P Q" unfolding similar_mat_def by auto
  from similar_mat_wit_pow_id[OF this] have pow: "⋀ k. A ^⇩m k = P * ?J ^⇩m k * Q" .
  from sim_wit[unfolded similar_mat_wit_def Let_def] A 
  have J: "?J ∈ carrier_mat n n" and P: "P ∈ carrier_mat n n" and Q: "Q ∈ carrier_mat n n"
    unfolding carrier_mat_def by force+
  have "∃c1. ∀ k. norm_bound (?J ^⇩m k) (c1 + of_nat k ^ (N - 1))"
    by (rule jordan_matrix_poly_bound[OF n_as N])
  then obtain c1 where 
    bound_pow: "⋀ k. norm_bound ((?J ^⇩m k)) (c1 + of_nat k ^ (N - 1))" by blast
  obtain bP where bP: "norm_bound P bP" using norm_bound_max[of P] by auto
  obtain bQ where bQ: "norm_bound Q bQ" using norm_bound_max[of Q] by auto
  let ?n = "of_nat n :: real"
  let ?c2 = "bP * ?n * bQ * ?n"
  let ?c1 = "?c2 * c1"
  {
    fix k
    have Jk: "?J ^⇩m k ∈ carrier_mat n n" using J by simp
    from norm_bound_mult[OF mult_carrier_mat[OF P Jk] Q 
      norm_bound_mult[OF P Jk bP bound_pow] bQ, folded pow] 
    have "norm_bound (A ^⇩m k) (?c1 + ?c2 * of_nat k ^ (N - 1))"  (is "norm_bound _ ?exp") 
      by (simp add: field_simps)
  } note main = this
  show ?thesis 
    by (intro exI allI, rule main)
qed
end

context 
  fixes f_ty :: "'a :: field itself"
begin
(* The characteristic matrix of a Jordan block at b is the Jordan block with
   diagonal value a - b. *)
lemma char_matrix_jordan_block: "char_matrix (jordan_block n a) b = (jordan_block n (a - b))"
  unfolding char_matrix_def jordan_block_def by auto

(* The diagonal of the k-th power of a Jordan block consists of n copies of
   a ^ k. *)
lemma diag_jordan_block_pow: "diag_mat (jordan_block n (a :: 'a) ^⇩m k) = replicate n (a ^ k)"
  unfolding diag_mat_def jordan_block_pow
  by (intro nth_equalityI, auto)

(* Powers of the nilpotent Jordan block (diagonal value 0): the k-th power
   has entry 1 exactly on the k-th superdiagonal (j - i = k with j ≥ i) and 0
   elsewhere. *)
lemma jordan_block_zero_pow: "(jordan_block n (0 :: 'a)) ^⇩m k = 
  (mat n n (λ (i,j). if j ≥ i ∧ j - i = k then 1 else 0))"
proof -
  {
    fix i j
    assume  *: "j - i ≠ k"
    have "of_nat (k choose (j - i)) * 0 ^ (k + i - j) = (0 :: 'a)"
    proof (cases "k + i - j > 0")
      case True thus ?thesis by (cases "k + i - j", auto)
    next
      case False
      with * have "j - i > k" by auto
      thus ?thesis by (simp add: binomial_eq_0)
    qed
  }
  thus ?thesis unfolding jordan_block_pow by (intro eq_matI, auto)
qed
end

(* The Jordan matrix of a concatenated block list is the block-diagonal
   matrix of the Jordan matrices of the individual lists. *)
lemma jordan_matrix_concat_diag_block_mat: "jordan_matrix (concat jbs) = diag_block_mat (map jordan_matrix jbs)"
  unfolding jordan_matrix_def[abs_def]
  by (induct jbs, auto simp: diag_block_mat_append Let_def)

(* Jordan normal forms compose blockwise: if each A in Ms has Jordan normal
   form jbs, then the block-diagonal matrix of all A has the concatenation of
   all jbs as Jordan normal form. *)
lemma jordan_nf_diag_block_mat: assumes Ms: "⋀ A jbs. (A,jbs) ∈ set Ms ⟹ jordan_nf A jbs"
  shows "jordan_nf (diag_block_mat (map fst Ms)) (concat (map snd Ms))"
proof -
  let ?Ms = "map (λ (A, jbs). (A, jordan_matrix jbs)) Ms"
  have id: "map fst ?Ms = map fst Ms" by auto
  have id2: "map snd ?Ms = map jordan_matrix (map snd Ms)" by auto
  {
    fix A B
    assume "(A,B) ∈ set ?Ms"
    then obtain jbs where mem: "(A,jbs) ∈ set Ms" and B: "B = jordan_matrix jbs" by auto
    from Ms[OF mem] have "similar_mat A B" unfolding B jordan_nf_def by auto
  }
  from similar_diag_mat_block_mat[of ?Ms, OF this, unfolded id id2] Ms
  show ?thesis
    unfolding jordan_nf_def jordan_matrix_concat_diag_block_mat by force
qed  


(* The characteristic polynomial of a matrix is determined by its Jordan
   normal form: the product of (x - a)^n over the blocks (n,a). *)
lemma jordan_nf_char_poly: assumes "jordan_nf A n_as"
  shows "char_poly A = (∏ (n,a) ← n_as. [:- a, 1:] ^ n)"
  unfolding jordan_matrix_char_poly[symmetric]
  by (rule char_poly_similar, insert assms[unfolded jordan_nf_def], auto)

(* The size of any Jordan block (n,a) is bounded by the order of a as a root
   of the characteristic polynomial, because (x - a)^n divides it. *)
lemma jordan_nf_block_size_order_bound: assumes jnf: "jordan_nf A n_as"
  and mem: "(n,a) ∈ set n_as"
  shows "n ≤ order a (char_poly A)"
proof -
  from jnf[unfolded jordan_nf_def]
  have "similar_mat A (jordan_matrix n_as)" by auto
  from similar_matD[OF this] obtain m where "A ∈ carrier_mat m m" by auto
  from degree_monic_char_poly[OF this] have A: "char_poly A ≠ 0" by auto
  from mem obtain as bs where nas: "n_as = as @ (n,a) # bs" 
    by (meson split_list)
  from jordan_nf_char_poly[OF jnf] 
  have cA: "char_poly A = (∏(n, a)←n_as. [:- a, 1:] ^ n)" .
  also have "… = [: -a, 1:] ^ n * (∏(n, a)← as @ bs. [:- a, 1:] ^ n)" unfolding nas by auto
  also have "[: -a,1 :] ^ n dvd …" unfolding dvd_def by blast
  finally have "[: -a,1 :] ^ n dvd char_poly A" by auto
  from order_max[OF this A] show ?thesis .
qed

(* Scaling by a non-zero k: if A is similar to the Jordan block with value a,
   then k ⋅ A is similar to the Jordan block with value k * a.  The witness
   matrices are the diagonal matrices of the powers k^i and their inverses. *)
lemma similar_mat_jordan_block_smult: fixes A :: "'a :: field mat" 
  assumes "similar_mat A (jordan_block n a)" 
   and k: "k ≠ 0" 
  shows "similar_mat (k ⋅⇩m A) (jordan_block n (k * a))" 
proof -
  let ?J = "jordan_block n a" 
  let ?Jk = "jordan_block n (k * a)" 
  let ?kJ = "k ⋅⇩m jordan_block n a" 
  from k have inv: "k ^ i ≠ 0" for i by auto
  let ?A = "mat_diag n (λ i. k^i)" 
  let ?B = "mat_diag n (λ i. inverse (k^i))"
  have "similar_mat_wit ?Jk ?kJ ?A ?B" 
  proof (rule similar_mat_witI)
    show "jordan_block n (k * a) = ?A * ?kJ * ?B"
      by (subst mat_diag_mult_left[of _ _ n], force, subst mat_diag_mult_right[of _ n],
       insert k inv, auto simp: jordan_block_def field_simps intro!: eq_matI)
  qed (auto simp: inv field_simps k)
  hence kJ: "similar_mat ?Jk ?kJ" 
    unfolding similar_mat_def by auto
  have "similar_mat A ?J" by fact
  hence "similar_mat (k ⋅⇩m A) (k ⋅⇩m ?J)" by (rule similar_mat_smult)
  with kJ show ?thesis
    using similar_mat_sym similar_mat_trans by blast
qed


(* Unfolding a Jordan matrix at its first block: a four-block matrix with the
   first Jordan block top-left, the remaining Jordan matrix bottom-right and
   zero matrices off the diagonal. *)
lemma jordan_matrix_Cons:  "jordan_matrix (Cons (n,a) n_as) = four_block_mat 
  (jordan_block n a)                 (0⇩m n (sum_list (map fst n_as))) 
  (0⇩m (sum_list (map fst n_as)) n)   (jordan_matrix n_as)" 
  unfolding jordan_matrix_def by (simp, simp add: jordan_matrix_def[symmetric])

(* Scaling a Jordan matrix by a non-zero k gives a matrix similar to the
   Jordan matrix in which every diagonal value a is replaced by k * a.
   Proved by induction on the block list via the four-block decomposition. *)
lemma similar_mat_jordan_matrix_smult:  fixes n_as :: "(nat × 'a :: field) list"
  assumes k: "k ≠ 0" 
  shows "similar_mat (k ⋅⇩m jordan_matrix n_as) (jordan_matrix (map (λ (n,a). (n, k * a)) n_as))" 
proof (induct n_as)
  case Nil
  show ?case by (auto simp: jordan_matrix_def intro!: similar_mat_refl)
next
  case (Cons na n_as)
  obtain n a where na: "na = (n,a)" by force
  let ?l = "map (λ (n,a). (n, k * a))" 
  let ?n = "sum_list (map fst n_as)" 
  have "k ⋅⇩m jordan_matrix (Cons na n_as) = k ⋅⇩m four_block_mat 
     (jordan_block n a) (0⇩m n ?n)
     (0⇩m ?n n) (jordan_matrix n_as)" (is "?M = _ ⋅⇩m four_block_mat ?A ?B ?C ?D")
    by (simp add: na jordan_matrix_Cons)
  also have "… = four_block_mat (k ⋅⇩m ?A) ?B ?C (k ⋅⇩m ?D)" 
    by (subst smult_four_block_mat, auto)
  finally have jm: "?M = four_block_mat (k ⋅⇩m ?A) ?B ?C (k ⋅⇩m ?D)" .
  have [simp]: "fst (case x of (n :: nat, a) ⇒ (n, k * a)) = fst x" for x by (cases x, auto)
  have jmk: "jordan_matrix (?l (Cons na n_as)) = four_block_mat
     (jordan_block n (k * a)) ?B
     ?C (jordan_matrix (?l n_as))" (is "?kM = four_block_mat ?kA _ _ ?kD")
    by (simp add: na jordan_matrix_Cons o_def)
  show ?case unfolding jmk jm
    by (rule similar_mat_four_block_0_0[OF similar_mat_jordan_block_smult[OF _ k] Cons],
      auto intro!: similar_mat_refl)
qed

(* Jordan normal forms are preserved under scaling by a non-zero k: block
   sizes stay the same and every diagonal value a becomes k * a. *)
lemma jordan_nf_smult: fixes k :: "'a :: field" 
  assumes jn: "jordan_nf A n_as" 
  and k: "k ≠ 0" 
  shows "jordan_nf (k ⋅⇩m A) (map (λ (n,a). (n, k * a)) n_as)" 
proof -
  let ?l = "map (λ (n,a). (n, k * a))" 
  from jn[unfolded jordan_nf_def] have sim: "similar_mat A (jordan_matrix n_as)" by auto
  from similar_mat_smult[OF this, of k] similar_mat_jordan_matrix_smult[OF k, of n_as]
  have "similar_mat (k ⋅⇩m A) (jordan_matrix (map (λ(n, a). (n, k * a)) n_as))" 
    using similar_mat_trans by blast
  with jn show ?thesis unfolding jordan_nf_def by force
qed

(* The order of a as a root of the characteristic polynomial equals the sum
   of the sizes of all Jordan blocks with diagonal value a. *)
lemma jordan_nf_order: assumes "jordan_nf A n_as" 
  shows "order a (char_poly A)  = sum_list (map fst (filter (λ na. snd na = a) n_as))" 
proof - 
  let ?p = "λ n_as. (∏(n, a)←n_as. [:- a, 1:] ^ n)" 
  let ?s = "λ n_as. sum_list (map fst (filter (λ na. snd na = a) n_as))" 
  from jordan_nf_char_poly[OF assms]
  have "order a (char_poly A) = order a (?p n_as)" by simp
  also have "… = ?s n_as" 
  proof (induct n_as)
    case (Cons nb n_as)
    obtain n b where nb: "nb = (n,b)" by force
    have "order a (?p (nb # n_as)) = order a ([: -b, 1:] ^ n * ?p n_as)" unfolding nb by simp
    also have "… = order a ([: -b, 1:] ^ n) + order a (?p n_as)" 
      by (rule order_mult, auto simp: prod_list_zero_iff)
    also have "… = (if a = b then n else 0) + ?s n_as" unfolding Cons order_linear_power by simp
    also have "… = ?s (nb # n_as)" unfolding nb by auto
    finally show ?case .
  qed simp
  finally show ?thesis .
qed

subsection ‹Application for Complexity›

(* Growth bound from a factored characteristic polynomial: if char_poly A is
   the product of the linear factors (x - a) for a in as, a Jordan normal
   form exists, all roots have norm ≤ 1, and roots of norm 1 have
   multiplicity at most N, then A^k is bounded by c1 + c2 * k ^ (N - 1).
   The block-size conditions of jordan_nf_matrix_poly_bound are derived by
   comparing the two factorisations of the characteristic polynomial. *)
lemma factored_char_poly_norm_bound: assumes A: "A ∈ carrier_mat n n"
  and linear_factors: "char_poly A = (∏ (a :: 'a :: real_normed_field) ← as. [:- a, 1:])"
  and jnf_exists: "∃ n_as. jordan_nf A n_as" 
  and le_1: "⋀ a. a ∈ set as ⟹ norm a ≤ 1"
  and le_N: "⋀ a. a ∈ set as ⟹ norm a = 1 ⟹ length (filter ((=) a) as) ≤ N"
  shows "∃ c1 c2. ∀ k. norm_bound (A ^⇩m k) (c1 + c2 * of_nat k ^ (N - 1))"
proof -
  from jnf_exists obtain n_as 
    where jnf: "jordan_nf A n_as" by auto
  let ?cp1 = "(∏ a ← as. [:- a, 1:])"
  let ?cp2 = "∏(n, a)←n_as. [:- a, 1:] ^ n"
  let ?J = "jordan_matrix n_as"
  from jnf[unfolded jordan_nf_def]
  have sim: "similar_mat A ?J" by auto
  from char_poly_similar[OF sim, unfolded linear_factors jordan_matrix_char_poly]
  have cp: "?cp1 = ?cp2" .
  show ?thesis
  proof (rule jordan_nf_matrix_poly_bound[OF A _ _ jnf])
    fix n a
    assume na: "(n,a) ∈ set n_as"
    then obtain na1 na2 where n_as: "n_as = na1 @ (n,a) # na2"
      unfolding in_set_conv_decomp by auto
    then obtain p where "?cp2 = [: -a, 1 :]^n * p" unfolding n_as by auto
    from cp[unfolded this] have dvd: "[: -a, 1 :] ^ n dvd ?cp1" by auto
    let ?as = "filter ((=) a) as"
    let ?pn = "λ as. ∏a←as. [:- a, 1:]"
    let ?p = "λ as. ∏a←as. [: a, 1:]"
    have "?pn as = ?p (map uminus as)" by (induct as, auto)
    from poly_linear_exp_linear_factors[OF dvd[unfolded this]] 
    have "n ≤ length (filter ((=) (- a)) (map uminus as))" .
    also have "… = length (filter ((=) a) as)" 
      by (induct as, auto)
    finally have filt: "n ≤ length (filter ((=) a) as)" .
    {
      assume "0 < n"
      with filt obtain b bs where "?as = b # bs" by (cases ?as, auto)
      from arg_cong[OF this, of set]
      have "a ∈ set as" by auto 
      from le_1[rule_format, OF this]
      show "norm a ≤ 1" .
      note ‹a ∈ set as›
    } note mem = this
    {
      assume "norm a = 1" 
      from le_N[OF mem this] filt show "n ≤ N" by (cases n, auto)
    }
  qed
qed

end

Theory Missing_VectorSpace

(*
    Author:      René Thiemann
                 Akihisa Yamada
                 Jose Divasón
    License:     BSD
*)
(* with contributions from Alexander Bentkamp, Universität des Saarlandes *)

section ‹Missing Vector Spaces›

text ‹This theory provides some lemmas which we required when working with vector spaces.›

theory Missing_VectorSpace
imports
  VectorSpace.VectorSpace
  Missing_Ring
  "HOL-Library.Multiset"
begin


(**** The following lemmas could be moved to HOL/Finite_Set.thy  ****)

(*This is a generalization of comp_fun_commute. It is a similar definition, but restricted to a set. 
  When "A = UNIV::'a set", we have "comp_fun_commute_on f A = comp_fun_commute f"*)
(* Parameters: f is the binary step function, A the set the assumptions are restricted to.
   comp_fun_commute_restrict: two applications of f can be swapped for arguments taken from A.
   f: a typing-style assumption relating f to A (used below to show that fold results stay in A). *)
locale comp_fun_commute_on = 
  fixes f :: "'a  'a  'a" and A::"'a set"
  assumes comp_fun_commute_restrict: "yA. xA. zA. f y (f x z) = f x (f y z)"
  and f: "f : A  A  A" 
begin

(* When the restricting set A is the whole type, the locale assumptions give
   the unrestricted comp_fun_commute property for f. *)
lemma comp_fun_commute_on_UNIV:
  assumes "A = (UNIV :: 'a set)"
  shows "comp_fun_commute f"
  unfolding comp_fun_commute_def 
  using assms comp_fun_commute_restrict f by auto

(* Rule form of comp_fun_commute_restrict: for y, x and z subject to the three
   assumptions about A, f y and f x may be swapped in f y (f x z). *)
lemma fun_left_comm: 
  assumes "y  A" and "x  A" and "z  A" shows "f y (f x z) = f x (f y z)"
  using comp_fun_commute_restrict assms by auto

(* Variant of fun_left_comm in which the innermost argument is g z instead of z,
   with an additional assumption relating g to A. *)
lemma commute_left_comp: 
  assumes "y  A" and "xA" and "zA" and "g  A  A" 
  shows "f y (f x (g z)) = f x (f y (g z))"
  using assms by (auto simp add: Pi_def o_assoc comp_fun_commute_restrict)

(* Any set B for which the fold_graph relation holds is finite.
   Proved by induction on the fold_graph derivation. *)
lemma fold_graph_finite:
  assumes "fold_graph f z B y"
  shows "finite B"
  using assms by induct simp_all

(* Closure of fold results: if fold_graph f z B y holds, with the stated side
   conditions on B and on the start value z relative to A, then the result y
   is in A. Induction on fold_graph; the step case uses the locale assumption f. *)
lemma fold_graph_closed:
  assumes "fold_graph f z B y" and "B  A" and "z  A"
  shows "y  A"
  using assms 
proof (induct set: fold_graph)
  case emptyI
  then show ?case by auto
next
  case (insertI x B y)
  then show ?case using insertI f by auto
qed

(* Auxiliary decomposition: from fold_graph f z B y and an element a of B, the
   result can be written as y = f a y' where y' is a fold_graph result for
   B - {a} and satisfies the stated condition relative to A.
   Proved by induction on fold_graph, with a case split on whether the element
   inserted last is a itself. *)
lemma fold_graph_insertE_aux:
  "fold_graph f z B y  a  B  zA
   B  A
    y'. y = f a y'  fold_graph f z (B - {a}) y'  y'  A"
proof (induct set: fold_graph)
  case emptyI
  then show ?case by auto
next
  case (insertI x B y)
  show ?case
  proof (cases "x = a")
    case True 
    (* The last inserted element is a: the previous result y is the witness. *)
    show ?thesis
    proof (rule exI[of _ y])
      have B: "(insert x B - {a}) = B" using True insertI by auto 
      have "f x y = f a y" by (simp add: True) 
      moreover have "fold_graph f z (insert x B - {a}) y" by (simp add: B insertI)
      moreover have "y  A" using insertI fold_graph_closed[of z B] by auto
      ultimately show " f x y = f a y  fold_graph f z (insert x B - {a}) y  y  A" by simp
    qed
  next
    case False
    (* Otherwise use the induction hypothesis for a and commute f x past f a. *)
    then obtain y' where y: "y = f a y'" and y': "fold_graph f z (B - {a}) y'" and y'_in_A: "y'  A"
      using insertI f by auto
    have "f x y = f a (f x y')"
      unfolding y 
    proof (rule fun_left_comm)
      show "x  A" using insertI by auto
      show "a  A" using insertI by auto
      show "y'  A" using y'_in_A by auto
    qed  
    moreover have "fold_graph f z (insert x B - {a}) (f x y')"
      using y' and x  a and x  B
      by (simp add: insert_Diff_if fold_graph.insertI)    
    moreover have "(f x y')  A" using insertI f y'_in_A by auto
    ultimately show ?thesis using y'_in_A
      by auto
  qed
qed
    
(* Elimination rule derived from fold_graph_insertE_aux (instantiated with
   insertI1): a fold_graph result v for insert x B can be written as f x y,
   where y is a fold_graph result for B. *)
lemma fold_graph_insertE:
  assumes "fold_graph f z (insert x B) v" and "x  B" and "insert x B  A" and "zA"
  obtains y where "v = f x y" and "fold_graph f z B y"
  using assms by (auto dest: fold_graph_insertE_aux [OF _ insertI1])

(* Determinism: under the side conditions on B and z relative to A, two
   fold_graph results x and y for the same f, z and B are equal.
   Induction on the first derivation, with y generalized; the step case splits
   the second derivation with fold_graph_insertE. *)
lemma fold_graph_determ: "fold_graph f z B x  fold_graph f z B y   B  A  zA  y = x"
proof (induct arbitrary: y set: fold_graph)
  case emptyI
  then show ?case
    by (meson empty_fold_graphE)
next
  case (insertI x B y v)
  from ‹fold_graph f z (insert x B) v and x  B and ‹insert x B  A and  z  A
  obtain y' where "v = f x y'" and "fold_graph f z B y'"
    by (rule fold_graph_insertE)
  from ‹fold_graph f z B y' and ‹insert x B  A have "y' = y" using insertI by auto    
  with v = f x y' show "v = f x y"
    by simp
qed

(* Finite_Set.fold f z B equals any fold_graph result y for the same data.
   Uses fold_graph_determ for uniqueness and fold_graph_finite for finiteness of B. *)
lemma fold_equality: "fold_graph f z B y  B  A  z  A  Finite_Set.fold f z B = y"
  by (cases "finite B") 
  (auto simp add: Finite_Set.fold_def intro: fold_graph_determ dest: fold_graph_finite)
    
(* For a finite B (with the side conditions BA and z), the value
   Finite_Set.fold f z B is itself a fold_graph result. *)
lemma fold_graph_fold:
  assumes f: "finite B" and BA: "BA" and z: "z  A"
  shows "fold_graph f z B (Finite_Set.fold f z B)"
proof -
  (* A result is obtained from finite_imp_fold_graph; fold_graph_determ makes it unique. *)
   have "x. fold_graph f z B x"
    by (rule finite_imp_fold_graph[OF f])
  moreover note fold_graph_determ
  ultimately have "∃!x. fold_graph f z B x" using f BA z by auto    
  (* With a unique result, the definite description in Finite_Set.fold_def denotes it. *)
  then have "fold_graph f z B (The (fold_graph f z B))"
    by (rule theI')
  with assms show ?thesis
    by (simp add: Finite_Set.fold_def)
qed
  
(*This lemma is a generalization of thm comp_fun_commute.fold_insert*)
(* Folding over insert x B equals one application of f x to the fold over B,
   for finite B and the stated side conditions on x, insert x B and z.
   Declared [simp]. Proof: build the fold_graph derivation for insert x B from
   fold_graph_fold and fold_graph.insertI, then conclude with fold_equality. *)
lemma fold_insert [simp]:
  assumes "finite B" and "x  B" and BA: "insert x B  A" and z: "z  A"
  shows "Finite_Set.fold f z (insert x B) = f x (Finite_Set.fold f z B)"
  proof (rule fold_equality[OF _ BA z])
  from ‹finite B have "fold_graph f z B (Finite_Set.fold f z B)"
   using BA fold_graph_fold z by auto
  hence "fold_graph f z (insert x B) (f x (Finite_Set.fold f z B))"
    using BA  fold_graph.insertI assms by auto
  then show "fold_graph f z (insert x B) (f x (Finite_Set.fold f z B))"
    by simp
qed
end

(*This lemma is a generalization of thm Finite_Set.fold_cong *)
(* Congruence for Finite_Set.fold with set-restricted commutation: if f and g
   both satisfy comp_fun_commute_on for A and agree on the finite set S, and the
   start values and sets are equal (s = t, S = T), then the two folds are equal.
   SA and s are the side conditions on S and s relative to A. *)
lemma fold_cong:
  assumes f: "comp_fun_commute_on f A" and g: "comp_fun_commute_on g A"
    and "finite S"
    and cong: "x. x  S  f x = g x"
    and "s = t" and "S = T" 
    and SA: "S  A" and s: "sA"
  shows "Finite_Set.fold f s S = Finite_Set.fold g t T"
proof -
  (* Core statement with identical start value and set, by induction on the finite set S. *)
  have "Finite_Set.fold f s S = Finite_Set.fold g s S"
    using ‹finite S cong SA s
  proof (induct S)
    case empty
    then show ?case by simp
  next
    case (insert x F)
    (* Interpreting the locale for f and for g brings both fold_insert simp rules into scope. *)
    interpret f: comp_fun_commute_on f A by (fact f)
    interpret g: comp_fun_commute_on g A by (fact g)
    show ?case  using insert by auto
  qed
  with assms show ?thesis by simp
qed
                    
context comp_fun_commute_on
begin  

(* The iterated function (%x. f x ^^ g x) satisfies the same typing-style
   property with respect to A as the locale assumption f.
   The pointwise statement is proved by induction on the value g x, with g
   generalized so that the induction hypothesis can be used for a modified
   exponent function h. *)
lemma comp_fun_Pi: "(λx. f x ^^ g x)  A  A  A"
proof -    
  have "(f x ^^ g x) y  A" if y: "y  A" and x: "x  A" for x y
    using x y
   proof (induct "g x" arbitrary: g)
     case 0
     then show ?case by auto
   next
     case (Suc n g)
     (* h is g with every value decreased by one; at x this gives g x = Suc (h x). *)
     define h where "h z = g z - 1" for z
     have hyp: "(f x ^^ h x) y  A" 
       using h_def Suc.prems Suc.hyps diff_Suc_1 by metis
     have "g x = Suc (h x)" unfolding h_def
       using Suc.hyps(2) by auto     
     then show ?case using f x hyp unfolding Pi_def by auto
   qed 
  thus ?thesis by (auto simp add: Pi_def)
qed

(*This lemma is a generalization of thm comp_fun_commute.comp_fun_commute_funpow *)
(* The iterated function (%x. f x ^^ g x) again satisfies comp_fun_commute_on
   for the same set A. The commutation property is shown pointwise: the case
   x = y is discharged by simp; for distinct x and y there is an outer
   induction on g x and, inside it, an inner induction on g y (fact hyp1). *)
lemma comp_fun_commute_funpow: "comp_fun_commute_on (λx. f x ^^ g x) A"
proof -
  have f: " (f y ^^ g y) ((f x ^^ g x) z) = (f x ^^ g x) ((f y ^^ g y) z)"
    if x: "xA" and y: "y  A" and z: "z  A" for x y z
  proof (cases "x = y")
    case False
    show ?thesis
    proof (induct "g x" arbitrary: g)
      case (Suc n g)
      (* hyp1: a single application f x commutes with the iterate f y ^^ g y. *)
      have hyp1: "(f y ^^ g y) (f x k) = f x ((f y ^^ g y) k)" if k: "k  A" for k
      proof (induct "g y" arbitrary: g)
        case 0
        then show ?case by simp
      next
        case (Suc n g)       
        define h where "h z = g z - 1" for z
        with Suc have "n = h y"
          by simp
        with Suc have hyp: "(f y ^^ h y) (f x k) = f x ((f y ^^ h y) k)"
          by auto
        from Suc h_def have g: "g y = Suc (h y)"
          by simp
        have "((f y ^^ h y) k)  A" using y k comp_fun_Pi[of h] unfolding Pi_def by auto
        then show ?case
          by (simp add: comp_assoc g hyp) (auto simp add: o_assoc comp_fun_commute_restrict x y k)
      qed
      (* h lowers the exponent only at x, so the outer induction hypothesis
         applies to h while the exponent at y is unchanged (x and y differ). *)
      define h where "h a = (if a = x then g x - 1 else g a)" for a
      with Suc have "n = h x"
        by simp
      with Suc have "(f y ^^ h y) ((f x ^^ h x) z) = (f x ^^ h x) ((f y ^^ h y) z)"
        by auto
      with False have Suc2: "(f x ^^ h x) ((f y ^^ g y) z) = (f y ^^ g y) ((f x ^^ h x) z)"
        using h_def by auto      
      from Suc h_def have g: "g x = Suc (h x)"
        by simp
      have "(f x ^^ h x) z A" using comp_fun_Pi[of h] x z unfolding Pi_def by auto
      hence *: "(f y ^^ g y) (f x ((f x ^^ h x) z)) = f x ((f y ^^ g y) ((f x ^^ h x) z))" 
        using hyp1 by auto
      thus ?case using g Suc2 by auto
    qed simp
  qed simp
  thus ?thesis by (auto simp add: comp_fun_commute_on_def comp_fun_Pi o_def)
qed

(*This lemma is a generalization of thm comp_fun_commute.fold_mset_add_mset*)
(* Folding over a multiset with one more copy of x equals one application of
   f x to the fold over M, given the side conditions MA, s and x relative to A.
   The proof unfolds fold_mset_def and works with Finite_Set.fold of the
   count-iterated functions; it splits on whether x already occurs in M. *)
lemma fold_mset_add_mset: 
  assumes MA: "set_mset M  A" and s: "s  A" and x: "x  A"
  shows "fold_mset f s (add_mset x M) = f x (fold_mset f s M)"
proof -
  (* Two locale instances: one for the counts of M, one for the counts of add_mset x M. *)
  interpret mset: comp_fun_commute_on "λy. f y ^^ count M y" A
    by (fact comp_fun_commute_funpow)
  interpret mset_union: comp_fun_commute_on "λy. f y ^^ count (add_mset x M) y" A
    by (fact comp_fun_commute_funpow)
  show ?thesis
  proof (cases "x  set_mset M")
    case False
    (* x is new: its count in add_mset x M is 1, and the two iterated functions
       agree on set_mset M by fold_cong. *)
    then have *: "count (add_mset x M) x = 1"
      by (simp add: not_in_iff)
     have "Finite_Set.fold (λy. f y ^^ count (add_mset x M) y) s (set_mset M) =
      Finite_Set.fold (λy. f y ^^ count M y) s (set_mset M)"
       by (rule fold_cong[of _ A], auto simp add: assms False comp_fun_commute_funpow)
    with False * s MA x show ?thesis
      by (simp add: fold_mset_def del: count_add_mset)
  next
    case True
    (* x already occurs: split set_mset M as insert x N with N = set_mset M - {x},
       peel off x in both folds (F, F2) and compare the folds over N. *)
    let ?f = "(λxa. f xa ^^ count (add_mset x M) xa)"
    let ?f2 = "(λx. f x ^^ count M x)"
    define N where "N = set_mset M - {x}"
    have F: "Finite_Set.fold ?f s (insert x N) = ?f x (Finite_Set.fold ?f s N)" 
      by (rule mset_union.fold_insert, auto simp add: assms N_def)
    have F2: "Finite_Set.fold ?f2 s (insert x N) = ?f2 x (Finite_Set.fold ?f2 s N)"
      by (rule mset.fold_insert, auto simp add: assms N_def)
    from N_def True have *: "set_mset M = insert x N" "x  N" "finite N" by auto
    then have "Finite_Set.fold (λy. f y ^^ count (add_mset x M) y) s N =
      Finite_Set.fold (λy. f y ^^ count M y) s N" 
      using MA N_def s 
      by (auto intro!: fold_cong comp_fun_commute_funpow)
    with * show ?thesis by (simp add: fold_mset_def del: count_add_mset, unfold F F2, auto)      
  qed
qed
end

(**** End of the lemmas that could be moved to HOL/Finite_Set.thy  ****)


(* Going by the name, a is not a member of A - {a}; discharged by auto.
   Passed to lincomb_insert in lindep_span. *)
lemma Diff_not_in: "a  A - {a}" by auto

context abelian_group begin

(* finsum depends only on the values of the summand on the index set:
   if restrict f A = restrict g A (and fA holds for f), then
   finsum G f A = finsum G g A. Proved through finsum_cong'. *)
lemma finsum_restrict:
  assumes fA: "f : A  carrier G"
      and restr: "restrict f A = restrict g A"
  shows "finsum G f A = finsum G g A"
proof (rule finsum_cong';rule?)
  fix a assume a: "a : A"
  (* On A both functions coincide with their restrictions, which are equal by restr. *)
  have "f a = restrict f A a" using a by simp
  also have "... = restrict g A a" using restr by simp
  also have "... = g a" using a by simp
  finally show "f a = g a".
  thus "g a : carrier G" using fA a by force
qed

(* Presumably: the additive inverse of a non-zero carrier element is non-zero
   (inferred from the lemma name and the use of r_neg -- confirm against the
   original statement). *)
lemma minus_nonzero: "x : carrier G  x  𝟬   x  𝟬"
  using r_neg by force

end

(* For a finite index set X and a function f whose values on X lie in the set
   given in the second assumption (presumably the non-negative elements, since
   the proof uses add_nonneg_eq_0_iff), the lemma states a bound on sum f X
   together with the consequence that a zero sum constrains f ` X relative to {0}.
   Induction on the finite set X. *)
lemma (in ordered_comm_monoid_add) positive_sum:
  assumes X : "finite X"
      and "f : X  { y :: 'a. y  0 }"
  shows "sum f X  0  (sum f X = 0  f ` X  {0})"
  using assms
proof (induct set:finite)
  case (insert x X)
    hence x0: "f x  0" and sum0: "sum f X  0" by auto
    hence "sum f (insert x X)  0" using insert by auto
    moreover
    { assume "sum f (insert x X) = 0"
      (* A zero total forces both the new summand and the remaining sum to be zero. *)
      hence "f x = 0" "sum f X = 0"
        using sum0 x0 insert add_nonneg_eq_0_iff by auto
    }
    ultimately show ?case using insert by blast
qed auto


(* Rewrites insert x X as a combination of X with the singleton {x}
   (a union, going by the lemma name). Used to restate lincomb_insert as lincomb_insert2. *)
lemma insert_union: "insert x X = X  {x}" by auto


context vectorspace begin

(* lincomb_insert with insert_union applied right-to-left, so that the rule is
   phrased in terms of insert rather than the form on the right of insert_union. *)
lemmas lincomb_insert2 = lincomb_insert[unfolded insert_union[symmetric]]

(* A linear combination over U depends only on the coefficient function
   restricted to U: if restrict a U = restrict b U then
   lincomb a U = lincomb b U. Reduces to finsum_restrict on the summands. *)
lemma lincomb_restrict:
  assumes U: "U  carrier V"
      and a: "a : U  carrier K"
      and restr: "restrict a U = restrict b U"
  shows "lincomb a U = lincomb b U"
proof -
  (* ?f c is the summand built from a coefficient function c. *)
  let ?f = "λa u. a u V u"
  have fa: "?f a : U  carrier V" using a U by auto
  have "restrict (?f a) U = restrict (?f b) U"
  proof
    fix u
    have "u : U  a u = b u" using restr unfolding restrict_def by metis
    thus "restrict (?f a) U u = restrict (?f b) U u" by auto
  qed
  thus ?thesis
    unfolding lincomb_def using finsum_restrict[OF fa] by auto
qed

(* Characterisation of linear dependence for a finite U (with side condition U
   relative to carrier V): lin_dep U holds exactly when some element u of U
   lies in span (U - {u}). Both directions are proved separately. *)
lemma lindep_span:
  assumes U: "U  carrier V" and finU: "finite U"
  shows "lin_dep U = (u  U. u  span (U - {u}))" (is "?l = ?r")
proof
  assume l: "?l" show "?r"
  proof -
    (* Unfold lin_dep: a finite A inside U, coefficients a with
       lincomb a A = zero V, and an element u of A whose coefficient is
       constrained by au0. *)
    from l[unfolded lin_dep_def]
    obtain A a u
    where finA: "finite A"
      and AU: "A  U"
      and aA: "a : A  carrier K"
      and aA0: "lincomb a A = zero V"
      and uA: "u:A"
      and au0: "a u  zero K" by auto
    (* Extend the coefficients to all of U by zero K outside A. *)
    define a' where "a' = (λv. (if v : A then a v else zero K))"
    have a'U: "a' : U  carrier K" unfolding a'_def using aA by auto
    have uU: "u : U" using uA AU by auto
    have a'u0: "a' u  zero K" unfolding a'_def using au0 uA by auto
    (* B collects the elements of U that are not in A. *)
    define B where "B = U - A"
    have B: "B  carrier V" unfolding B_def using U by auto
    have UAB: "U = A  B" unfolding B_def using AU by auto
    have finB: "finite B" using finU B_def by auto
    have AB: "A  B = {}" unfolding B_def by auto
    let ?f = "λv. a v V v"
    have fA: "?f : A  carrier V" unfolding a'_def using aA AU U by auto
    let ?f' = "λv. a' v V v"
    (* On A the extended coefficients give the same sum as the original ones. *)
    have "restrict ?f A = restrict ?f' A" unfolding a'_def by auto
    hence finsum: "finsum V ?f' A = finsum V ?f A"
      using finsum_restrict[OF fA] by simp
    have f'A: "?f' : A  carrier V"
    proof
      fix x assume xA: "x  A"
      show "?f' x : carrier V" unfolding a'_def using aA xA AU U by auto
    qed
    have f'B: "?f' : B  carrier V"
    proof
      fix x assume xB: "x : B"
      have "x  A" using a'U xB unfolding B_def by auto
      thus "?f' x : carrier V"using xB B unfolding a'_def by auto
    qed
    (* The sum of the extended summands over B is zero V; shown by induction
       over finite subsets B' of B. *)
    have sumB0: "finsum V ?f' B = zero V"
    proof -
      { fix B'
        have "finite B'  B'  B  finsum V ?f' B' = zero V"
        proof(induct set:finite)
          case (insert b B')
            have finB': "finite B'" and bB': "b  B'" using insert by auto
            have f'B': "?f' : B'  carrier V" using f'B insert by auto
            have bA: "b  A" using insert unfolding B_def by auto
            have b: "b : carrier V" using insert B by auto
            have foo: "a' b V b  carrier V" unfolding a'_def using bA b by auto
            have IH: "finsum V ?f' B' = zero V" using insert by auto
            show ?case
              unfolding finsum_insert[OF finB' bB' f'B' foo]
              using IH a'_def bA b by auto
         qed auto
      }
      thus ?thesis using finB by auto
    qed
    (* Split the sum over U into the parts over A and B; both are zero V. *)
    have a'A0: "lincomb a' U = zero V"
      unfolding UAB
      unfolding lincomb_def
      unfolding finsum_Un_disjoint[OF finA finB AB f'A f'B]
      unfolding finsum
      unfolding aA0[unfolded lincomb_def]
      unfolding sumB0 by simp
    have uU: "u:U" using uA AU by auto
    (* lincomb_isolate(2) turns the vanishing combination into membership of u
       in span (U - {u}). *)
    moreover have "u : span (U - {u})"
      using lincomb_isolate(2)[OF finU U a'U uU a'u0 a'A0].
    ultimately show ?r by auto
  qed
  next assume r: "?r" show "?l"
  proof -
    from r obtain u where uU: "u : U" and uspan: "u : span (U-{u})" by auto
    hence u: "u : carrier V" using U by auto
    have finUu: "finite (U-{u})" using finU by auto
    have Uu: "U-{u}  carrier V" using U by auto
    (* Write u as a linear combination of the remaining elements of U. *)
    obtain a
      where ulin: "u = lincomb a (U-{u})"
        and a: "a : U-{u}  carrier K"
      using uspan unfolding finite_span[OF finUu Uu] by auto
    show "?l" unfolding lin_dep_def
    proof(intro exI conjI)
      (* Witness coefficients: a away from u, and at u the scalar that is
         rewritten below with smult_minus_1. *)
      let ?a = "λv. if v = u then K one K else a v"
      show "?a : U  carrier K" using a by auto
      hence a': "?a : U-{u}{u}  carrier K" by auto
      have "U = U-{u}{u}" using uU by auto
      also have "lincomb ?a ... = ?a u V u V lincomb ?a (U-{u})"
        unfolding lincomb_insert[OF finUu Uu a' Diff_not_in u] by auto
      also have "restrict a (U-{u}) = restrict ?a (U-{u})" by auto
        hence "lincomb ?a (U-{u}) = lincomb a (U-{u})"
          using lincomb_restrict[OF Uu a] by auto
      also have "?a u V u = V u" using smult_minus_1[OF u] by simp
      also have "lincomb a (U-{u}) = u" using ulin..
      also have "V u V u = zero V" using l_neg[OF u].
      finally show "lincomb ?a U = zero V" by auto
    qed (insert uU finU, auto)
  qed
qed

(* Destruction rule for linear independence: if S is not lin_dep, then any
   coefficient function f on a finite A (related to S by the third assumption)
   whose linear combination is zero V maps A into {zero K}. *)
lemma not_lindepD:
  assumes "~ lin_dep S"
      and "finite A" "A  S" "f : A  carrier K" "lincomb f A = zero V"
    shows "f : A  {zero K}"
  using assms unfolding lin_dep_def by blast


(* Every element u of E lies in span E. The witness is the linear combination
   over the singleton {u} with constant coefficient one K. *)
lemma span_mem:
  assumes E: "E  carrier V" and uE: "u : E" shows "u : span E"
  unfolding span_def
proof (rule,intro exI conjI)
  show "u = lincomb (λ_. one K) {u}" unfolding lincomb_def using uE E by auto
  show "{u}  E" using uE by auto
qed auto

(* Scaling a linear combination by a scalar c equals the linear combination
   with every coefficient combined with c (coefficient function ?b).
   Proved with infinite_finite_induct, so finite and infinite U are both covered. *)
lemma lincomb_distrib:
  assumes U: "U  carrier V"
      and a: "a : U  carrier K"
      and c: "c : carrier K"
  shows "c V lincomb a U = lincomb (λu. c K a u) U"
    (is "_ = lincomb ?b U")
  using U a
proof (induct U rule: infinite_finite_induct)
  case empty show ?case unfolding lincomb_def using c by simp next
  case (insert u U)
    then have U: "U  carrier V"
          and u: "u : carrier V"
          and a: "a : insert u U  carrier K"
          and finU: "finite U" by auto
    hence aU: "a : U  carrier K" by auto
    have b: "?b : insert u U  carrier K" using a c by auto
    (* Expand both linear combinations over insert u U, then finish with
       smult_r_distr and smult_assoc1. *)
    show ?case
      unfolding lincomb_insert2[OF finU U a uU u]
      unfolding lincomb_insert2[OF finU U b uU u]
      using insert U aU c u smult_r_distr smult_assoc1 by auto next
  case (infinite U)
    thus ?case unfolding lincomb_def using assms by simp
qed

(* Exchange step for spans: for finite E and vectors u, v, where usE relates u
   to span E and u lies in span (insert v E), every element of
   span (insert u E) lies in span (insert v E) (the proof fixes an x in ?L and
   shows it is in ?R).
   Idea: write u over insert v E with coefficients ua, write x over insert u E
   with coefficients xa, substitute and regroup into coefficients a'' over
   insert v E. *)
lemma span_swap:
  assumes finE[simp]: "finite E"
      and E[simp]: "E  carrier V"
      and u[simp]: "u : carrier V"
      and usE: "u  span E"
      and v[simp]: "v : carrier V"
      and usEv: "u : span (insert v E)"
  shows "span (insert u E)  span (insert v E)" (is "?L  ?R")
proof -
  have Eu[simp]: "insert u E  carrier V" by auto
  have Ev[simp]: "insert v E  carrier V" by auto
  have finEu: "finite (insert u E)" and finEv: "finite (insert v E)"
    using finE by auto
  have uE: "u  E" using usE span_mem by force
  have vE: "v  E"
  proof
    (* If v were in E, inserting it would change nothing and u would be in span E. *)
    assume "v : E"
    hence EvE: "insert v E = E" using insert_absorb by auto
    hence "u : span E" using usEv by auto
    thus "False" using usE by auto
  qed
  (* Coefficients ua expressing u over insert v E. *)
  obtain ua
    where ua[simp]: "ua : (insert v E)  carrier K"
      and uua: "u = lincomb ua (insert v E)"
    using usEv finite_span[OF finEv Ev]  by auto
  hence uaE[simp]: "ua : E  carrier K" and [simp]: "ua v : carrier K"
    by auto

  show "?L  ?R"
  proof
    fix x assume "x : ?L"
    (* Coefficients xa expressing x over insert u E. *)
    then obtain xa
    where xa: "xa : insert u E  carrier K"
      and xxa: "x = lincomb xa (insert u E)"
      unfolding finite_span[OF finEu Eu] by auto
    hence xaE[simp]: "xa : E  carrier K" and xau[simp]: "xa u : carrier K" by auto
    show "x : span (insert v E)"
      unfolding finite_span[OF finEv Ev]
    proof (rule,intro exI conjI)
      (* a: contribution of u's expansion on E; a': a combined with xa on E;
         a'': a' away from v, with a separate value at v. *)
      define a where "a = (λe. xa u K ua e)"
      define a' where "a' = (λe. a e K xa e)"
      define a'' where "a'' = (λe. if e = v then xa u K ua v else a' e)"
      have aE: "a : E  carrier K" unfolding a_def using xau uaE u by blast
      hence a'E: "a' : E  carrier K" unfolding a'_def using xaE by blast
      thus a'': "a'' : insert v E  carrier K" unfolding a''_def by auto
      have restr: "restrict a' E = restrict a'' E"
        unfolding a''_def using vE by auto
      (* Calculation: expand x, substitute the expansion of u, distribute and regroup. *)
      have "x = xa u V u V lincomb xa E"
        using xxa lincomb_insert2 finE E xa uE u by auto
      also have
        "xa u V u = xa u V lincomb ua (insert v E)"
        using uua by auto
      also have "lincomb ua (insert v E) = ua v V v V lincomb ua E"
        using lincomb_insert2 finE E ua vE v by auto
      also have "xa u V ... = xa u V (ua v V v) V xa u V lincomb ua E"
        using smult_r_distr by auto
      also have "xa u V lincomb ua E = lincomb a E"
        unfolding a_def using lincomb_distrib[OF E] by auto
      finally have "x = xa u V (ua v V v) V (lincomb a E V lincomb xa E)"
        using a_assoc xau v aE xaE by auto
      also have "lincomb a E V lincomb xa E = lincomb a' E"
        unfolding a'_def using lincomb_sum[OF finE E aE xaE]..
      also have "... = lincomb a'' E"
        using lincomb_restrict[OF E a'E ] restr by auto
      finally have "x = ((xa u K ua v) V v) V lincomb a'' E"
        using smult_assoc1 by auto
      also have "xa u K ua v = a'' v" unfolding a''_def by simp
      also have "(a'' v V v) V lincomb a'' E = lincomb a'' (insert v E)"
        using lincomb_insert2[OF finE E a'' vE] by auto
      finally show "x = lincomb a'' (insert v E)".
    qed
  qed
qed

(* Basis exchange: if insert u E is a basis (E finite, side condition uE on u
   and E) and u lies in span (insert v E), then insert v E is a basis as well.
   After unfolding basis_def the three components are shown separately:
   the carrier condition (Ev), linear independence, and the span equation. *)
lemma basis_swap:
  assumes finE[simp]: "finite E"
      and u[simp]: "u : carrier V"
      and uE[simp]: "u  E"
      and b: "basis (insert u E)"
      and v[simp]: "v : carrier V"
      and uEv: "u : span (insert v E)"
  shows "basis (insert v E)"
  unfolding basis_def
proof (intro conjI)
  (* Components of the assumption that insert u E is a basis. *)
  have Eu[simp]: "insert u E  carrier V"
   and spanEu: "carrier V = span (insert u E)"
   and indEu: "~ lin_dep (insert u E)"
    using b[unfolded basis_def] by auto
  hence E[simp]: "E  carrier V" by auto
  thus Ev[simp]: "insert v E  carrier V" using v by auto
  have finEu: "finite (insert u E)"
   and finEv: "finite (insert v E)" using finE by auto
  have usE: "u  span E"
  proof
    (* Otherwise lindep_span would make insert u E linearly dependent. *)
    assume "u : span E"
    hence "u : span (insert u E - {u})" using uE by auto
    moreover have "u : insert u E" by auto
    ultimately have "lin_dep (insert u E)"
      unfolding lindep_span[OF Eu finEu] by auto
    thus "False" using b unfolding basis_def by auto
  qed

  (* Coefficients ua expressing u over insert v E. *)
  obtain ua
    where ua[simp]: "ua : insert v E  carrier K"
      and uua: "u = lincomb ua (insert v E)"
    using uEv finite_span[OF finEv Ev]  by auto
  hence uaE[simp]: "ua : E  carrier K"
    and uav[simp]: "ua v : carrier K"
    by auto

  have vE: "v  E"
  proof
    assume "v : E"
    hence EvE: "insert v E = E" using insert_absorb by auto
    hence "u : span E" using uEv by auto
    thus "False" using usE by auto
  qed
  have vsE: "v  span E"
  proof
    (* If v were in span E, substituting its expansion into that of u would
       put u into span E, contradicting indEu. *)
    assume "v : span E"
    then obtain va
      where va[simp]: "va : E  carrier K"
        and vva: "v = lincomb va E"
      using finite_span[OF finE E] by auto
    define va' where "va' = (λe. ua v K va e)"
    define va'' where "va'' = (λe. va' e K ua e)"
    have va'[simp]: "va' : E  carrier K"
      unfolding va'_def using uav va by blast
    hence va''[simp]: "va'' : E  carrier K"
      unfolding va''_def using ua by blast
    from uua
    have "u = ua v V v V lincomb ua E"
      using lincomb_insert2 vE by auto
    also have "ua v V v = ua v V (lincomb va E)"
      using vva by auto
    also have "... = lincomb va' E"
      unfolding va'_def using lincomb_distrib by auto
    finally have "u = lincomb va'' E"
      unfolding va''_def
      using lincomb_sum[OF finE E] by auto
    hence "u : span E" using finite_span[OF finE E] va'' by blast
    hence "lin_dep (insert u E)" using lindep_span by simp
    then show False using indEu by auto
  qed

  have indE: "~ lin_dep E" using indEu subset_li_is_li by auto

  (* Independence of insert v E follows from independence of E and vsE. *)
  show "~ lin_dep (insert v E)" using lin_dep_iff_in_span[OF E indE v vE] vsE by auto

  show "span (insert v E) = carrier V" (is "?L = ?R")
  proof (rule)
    show "?L  ?R"
    proof
      have finEv: "finite (insert v E)" using finE by auto
      fix x assume "x : ?L"
      then obtain a
        where a: "a : insert v E  carrier K"
          and x: "x = lincomb a (insert v E)"
        unfolding finite_span[OF finEv Ev] by auto
      from a have av: "a v : carrier K" by auto
      from a have a2: "a : E  carrier K" by auto
      show "x : ?R"
        unfolding x
        unfolding lincomb_insert2[OF finE E a vE v]
        using lincomb_closed[OF E a2] av v
        by auto
    qed
    (* The converse direction comes from span_swap together with spanEu. *)
    show "?R  ?L"
      using span_swap[OF finE E u usE v uEv] spanEu by auto
  qed
qed

(* The span of the empty set consists of zero V only. *)
lemma span_empty: "span {} = {zero V}"
  unfolding finite_span[OF finite.emptyI empty_subsetI]
  unfolding lincomb_def by simp

(* A carrier vector v lies in the span of its own singleton; the witness is the
   linear combination over {v} with constant coefficient one K. *)
lemma span_self: assumes [simp]: "v : carrier V" shows "v : span {v}"
proof -
  have "v = lincomb (λx. one K) {v}" unfolding lincomb_def by auto
  thus ?thesis using finite_span by auto
qed

(* zero V is in the span of every set U; no carrier assumption on U is stated. *)
lemma span_zero: "zero V : span U" unfolding span_def lincomb_def by force

(* emb f D extends f by zero K outside D: it returns f x for x in D and zero K otherwise. *)
definition emb where "emb f D x = (if x : D then f x else zero K)"

(* On D, emb f D inherits the typing-style property assumed for f (target R). Declared [simp]. *)
lemma emb_carrier[simp]: "f : D  R  emb f D : D  R"
  apply rule unfolding emb_def by auto

(* Restricted to D, emb f D and f coincide. *)
lemma emb_restrict: "restrict (emb f D) D = restrict f D"
  apply rule unfolding restrict_def emb_def by auto

(* Outside D (on X - D, for any X), emb f D only takes the value zero K. *)
lemma emb_zero: "emb f D : X - D  { zero K }"
  apply rule unfolding emb_def by auto

(* Elements with zero coefficient can be dropped from a linear combination:
   for finite A and Z, if a maps Z into {zero K}, then the linear combination
   over A combined with Z (presumably their union) equals the one over A.
   Induction on the finite set Z. *)
lemma lincomb_clean:
  assumes A: "A  carrier V"
    and Z: "Z  carrier V"
    and finA: "finite A"
    and finZ: "finite Z"
    and aA: "a : A  carrier K"
    and aZ: "a : Z  { zero K }"
  shows "lincomb a (A  Z) = lincomb a A"
  using finZ Z aZ
proof(induct set:finite)
  case empty thus ?case by simp next
  case (insert z Z) show ?case
    proof (cases "z : A")
      (* z already in A: adding it to Z does not change the combined set. *)
      case True hence "A  insert z Z = A  Z" by auto
        thus ?thesis using insert by simp next
      case False
        (* z is new: its summand has coefficient zero K and therefore vanishes. *)
        have finAZ: "finite (A  Z)" using finA insert by simp
        have AZ: "A  Z  carrier V" using A insert by simp
        have a: "a : insert z (A  Z)  carrier K" using insert aA by force
        have "a z = zero K" using insert by auto
        also have "... V z = zero V" using insert by auto
        also have "... V lincomb a (A  Z) = lincomb a (A  Z)"
          using insert AZ aA by auto
        also have "... = lincomb a A" using insert by simp
        finally have "a z V z V lincomb a (A  Z) = lincomb a A".
        thus ?thesis
          using lincomb_insert2[OF finAZ AZ a] False insert by auto
    qed
qed

(* The span is closed under the module addition: for v and w in span U, their
   combination (a sum, going by the lemma name and the use of lincomb_sum) is
   in span U.
   Idea: take finite supports A and B with coefficients a and b, extend both
   coefficient functions by zero with emb to the combined support, and add them
   pointwise. *)
lemma span_add1:
  assumes U: "U  carrier V" and v: "v : span U" and w: "w : span U"
  shows "v V w : span U"
proof -
  from v obtain a A
    where finA: "finite A"
      and va: "lincomb a A = v"
      and AU: "A  U"
      and a: "a : A  carrier K"
    unfolding span_def by auto
  hence A: "A  carrier V" using U by auto
  from w obtain b B
    where finB: "finite B"
      and wb: "lincomb b B = w"
      and BU: "B  U"
      and b: "b : B  carrier K"
    unfolding span_def by auto
  hence B: "B  carrier V" using U by auto

  have B_A: "B - A  carrier V" and A_B: "A - B  carrier V" using A B by auto

  (* Properties of the zero-extended coefficient functions on the combined support. *)
  have a': "emb a A : A  B  carrier K"
    apply (rule Pi_I) unfolding emb_def using a by auto
  hence a'A: "emb a A : A  carrier K" by auto
  have a'Z: "emb a A : B - A  { zero K }"
    apply (rule Pi_I) unfolding emb_def using a by auto

  have b': "emb b B : A  B  carrier K"
    apply (rule Pi_I) unfolding emb_def using b by auto
  hence b'B: "emb b B : B  carrier K" by auto
  have b'Z: "emb b B : A - B  { zero K }"
    apply (rule Pi_I) unfolding emb_def using b by auto

  show ?thesis
    unfolding span_def
    proof (rule,intro exI conjI)
      let ?v = "lincomb (emb a A) (A  B)"
      let ?w = "lincomb (emb b B) (A  B)"
      let ?ab = "λu. (emb a A) u K (emb b B) u"
      show finAB: "finite (A  B)" using finA finB by auto
      show "A  B  U" using AU BU by auto
      show "?ab : A  B  carrier K" using a' b' by auto
      (* v and w are unchanged when expressed over the combined support
         (lincomb_restrict and lincomb_clean). *)
      have "v = ?v"
        using va lincomb_restrict[OF A a emb_restrict[symmetric]]
        using lincomb_clean[OF A B_A] a'A a'Z finA finB by simp
      moreover have "w = ?w"
        apply (subst Un_commute)
        using wb lincomb_restrict[OF B b emb_restrict[symmetric]]
        using lincomb_clean[OF B A_B] finA finB b'B b'Z by simp
      ultimately show "v V w = lincomb ?ab (A  B)"
        using lincomb_sum[OF finAB] A B a' b' by simp
    qed
qed

(* The span is closed under the module negation (going by the lemma name and
   the use of smult_minus_1_back): the negation of v is rewritten as a scalar
   multiple of v, and the scalar is pushed into the coefficients with
   lincomb_distrib. *)
lemma span_neg:
  assumes U: "U  carrier V" and vU: "v : span U"
  shows "V v : span U"
proof -
  have v: "v : carrier V" using vU U unfolding span_def by auto
  from vU[unfolded span_def]
  obtain a A
    where finA: "finite A"
      and AU: "A  U"
      and a: "a  A  carrier K"
      and va: "v = lincomb a A" by auto
  hence A: "A  carrier V" using U by simp
  (* ?a: the coefficients of v combined with the scalar used in smult_minus_1_back. *)
  let ?a = "λu. K one K K a u"

  have "V v = K one K V v" using smult_minus_1_back[OF v].
  also have "... = K one K V lincomb a A" using va by simp
  finally have main: "V v = lincomb ?a A"
    unfolding lincomb_distrib[OF A a R.a_inv_closed[OF R.one_closed]] by auto
  show ?thesis
    unfolding span_def
    apply rule
    using main a finA AU by force
qed

(* Elements of span U belong to carrier V, given the side condition on U. Declared [simp]. *)
lemma span_closed[simp]: "U  carrier V  v : span U  v : carrier V"
  unfolding span_def by auto

(* For v in span U and a carrier vector w, membership of w in span U is related
   to membership of the combination of v and w in span U (statement split as
   ?L and ?R; both directions are proved, so presumably an equivalence).
   One direction is span_add1; the other recovers w by adding the negation of v
   (span_neg) to the combination. *)
lemma span_add:
  assumes U: "U  carrier V" and vU: "v : span U" and w[simp]: "w : carrier V"
  shows "w : span U  v V w : span U" (is "?L  ?R")
proof
  show "?L  ?R" using span_add1[OF U vU] by auto
  assume R: "?R" show "?L"
  proof -
    have v[simp]: "v : carrier V" using vU U by simp
    have "w = zero V V w" using M.l_zero by auto
    also have "... = V v V v V w" using M.l_neg by auto
    also have "... = V v V (v V w)"
      using M.l_zero M.a_assoc M.a_closed by auto
    also have "... : span U" using span_neg[OF U vU] span_add1[OF U] R by auto
    finally show ?thesis.
  qed
qed


(* A linear combination over two finite sets U and U' that satisfy the
   disjointness assumption disj splits into the combination over U and the one
   over U'. Induction on the finite set U. *)
lemma lincomb_union:
  assumes U: "U  carrier V"
      and U'[simp]: "U'  carrier V"
      and disj: "U  U' = {}"
      and finU: "finite U"
      and finU': "finite U'"
      and a: "a : U  U'  carrier K"
    shows "lincomb a (U  U') = lincomb a U V lincomb a U'"
  using finU U disj a
proof (induct set:finite)
  case empty thus ?case by (subst(2) lincomb_def, simp) next
  case (insert u U) thus ?case
    unfolding Un_insert_left
    using lincomb_insert2 finU' insert a_assoc by auto
qed

(* One inclusion for spans of combined sets: if span U = span U' and
   span W = span W', then every element of the span of U combined with W lies
   in the span of U' combined with W' (the proof fixes x in ?L and shows it is
   in ?R). *)
lemma span_union1:
  assumes U: "U  carrier V" and U': "U'  carrier V" and UU': "span U = span U'"
      and W: "W  carrier V" and W': "W'  carrier V" and WW': "span W = span W'"
  shows "span (U  W)  span (U'  W')" (is "?L  ?R")
proof
  fix x assume "x : ?L"
  then obtain a A
    where finA: "finite A"
      and AUW: "A  U  W"
      and x: "x = lincomb a A"
      and a: "a : A  carrier K"
    unfolding span_def by auto
  (* Split the support A into a part belonging to U and a remaining part belonging to W. *)
  let ?AU = "A  U" and ?AW = "A  W - U"
  have AU: "?AU  carrier V" using U by auto
  have AW: "?AW  carrier V" using W by auto
  have disj: "?AU  ?AW = {}" by auto
  have U'W': "U'  W'  carrier V" using U' W' by auto

  have "?AU  ?AW = A" using AUW by auto
  hence "x = lincomb a (?AU  ?AW)" using x by auto
  hence "x = lincomb a ?AU V lincomb a ?AW"
    using lincomb_union[OF AU AW disj] finA a by auto
  moreover
    (* Each part lies in span U resp. span W, hence in span U' resp. span W',
       and by span_is_monotone in ?R. *)
    have "lincomb a ?AU : span U" and "lincomb a ?AW : span W"
      unfolding span_def using AU a finA by auto
    hence "lincomb a ?AU : span U'" and "lincomb a ?AW : span W'"
      using UU' WW' by auto
    hence "lincomb a ?AU : ?R" and "lincomb a ?AW : ?R"
      using span_is_monotone[OF Un_upper1, of U']
      using span_is_monotone[OF Un_upper2, of W'] by auto
  ultimately
    show "x : ?R" using span_add1[OF U'W'] by auto
qed

(* Equality version of span_union1: obtained by applying span_union1 once as
   stated and once with the roles of U/U' and W/W' exchanged. *)
lemma span_Un:
  assumes U: "U  carrier V" and U': "U'  carrier V" and UU': "span U = span U'"
      and W: "W  carrier V" and W': "W'  carrier V" and WW': "span W = span W'"
  shows "span (U  W) = span (U'  W')" (is "?L = ?R")
  using span_union1[OF assms]
  using span_union1[OF U' U UU'[symmetric] W' W WW'[symmetric]]
  by auto

(* A linear combination whose coefficient function maps U into {zero K} equals
   zero V. Proved with infinite_finite_induct; the infinite case is closed by
   unfolding lincomb_def. *)
lemma lincomb_zero:
  assumes U: "U  carrier V" and a: "a : U  {zero K}"
  shows "lincomb a U = zero V"
  using U a
proof (induct U rule: infinite_finite_induct)
  case empty show ?case unfolding lincomb_def by auto next
  case (insert u U)
    hence "a  insert u U  carrier K" using zero_closed by force
    thus ?case using insert by (subst lincomb_insert2; auto)
qed (auto simp: lincomb_def)

end

context module
begin

(* The linear combination over the empty set is the zero of the module. Declared [simp]. *)
lemma lincomb_empty[simp]: "lincomb a {} = 𝟬M"
  unfolding lincomb_def by auto

end

context linear_map
begin

(* The structure V.vs kerT is a vector space over K; its facts become available
   under the prefix Ker. Follows from kerT_is_subspace and V.subspace_is_vs. *)
interpretation Ker: vectorspace K "(V.vs kerT)"
  using kerT_is_subspace
  using V.subspace_is_vs by blast

(* The structure W.vs imT is a vector space over K; its facts become available
   under the prefix im. Follows from imT_is_subspace and W.subspace_is_vs. *)
interpretation im: vectorspace K "(W.vs imT)"
  using imT_is_subspace
  using W.subspace_is_vs by blast

(* If T is injective on carrier V, the carrier of V.vs kerT is the singleton
   containing the zero of V. *)
lemma inj_imp_Ker0:
assumes "inj_on T (carrier V)"
shows "carrier (V.vs kerT) = {𝟬V}"
  unfolding mod_hom.ker_def
  using assms inj_on_contraD by fastforce

(* Converse of inj_imp_Ker0: if the carrier of V.vs kerT is the singleton
   containing the zero of V, then T is injective on carrier V.
   For x, y with T x = T y, a combination of x and y (presumably their
   difference) is shown to lie in the kernel carrier, hence equals zero,
   which gives x = y. *)
lemma Ke0_imp_inj:
assumes c: "carrier (V.vs kerT) = {𝟬V}"
shows "inj_on T (carrier V)"
proof (auto simp add: inj_on_def)
  fix x y
  assume x: "x  carrier V" and y: "y  carrier V"
  and Tx_Ty: "T x = T y" 
  hence "T x W T y = 𝟬W" using W.module.M.minus_other_side by auto
  hence "T (x V y) = 𝟬W" by (simp add: x y)
  hence "x V y  carrier (V.vs kerT)" by (simp add: mod_hom.ker_def x y) 
  hence "x V y = 𝟬V" using c by fast
  thus "x = y" by (simp add: x y)
qed

(* Injectivity of T on carrier V is equivalent to the kernel carrier being the
   singleton zero; combines inj_imp_Ker0 and Ke0_imp_inj. *)
corollary Ke0_iff_inj: "inj_on T (carrier V) = (carrier (V.vs kerT) = {𝟬V})"
using inj_imp_Ker0 Ke0_imp_inj by auto

(* If T is injective on `carrier V`, the dimension of `V.vs kerT` is 0.
   Proof: unfold Ker.dim_def, reduce to the existence of a witness via
   Least_eq_0, and supply the empty set as that witness.  The four facts
   collected below (finite, cardinality 0, contained in the carrier,
   Ker.gen_set) are exactly the conjuncts of the final goal. *)
lemma inj_imp_dim_ker0:
assumes "inj_on T (carrier V)"
shows "vectorspace.dim K (V.vs kerT) = 0"
proof (unfold Ker.dim_def, rule Least_eq_0, rule exI[of _ "{}"])
    (* same argument as inj_imp_Ker0, repeated locally *)
    have Ker_rw: "carrier (V.vs kerT) = {𝟬V}" 
      unfolding mod_hom.ker_def
      using assms inj_on_contraD by fastforce
    have "finite {}" by simp 
    moreover have "card {} = 0" by simp
    moreover have "{}  carrier (V.vs kerT)" by simp
    moreover have "Ker.gen_set {}" unfolding Ker_rw by (simp add: Ker.span_empty)
    ultimately show "finite {}  card {} = 0  {}  carrier (V.vs kerT)  Ker.gen_set {}" by simp
qed


(* If the image of `carrier V` under T is all of `carrier W`, then `imT`
   equals `carrier W`.  Proved by simp with the assumption and im_def. *)
lemma surj_imp_imT_carrier:
assumes surj: "T` (carrier V) = carrier W"
shows "(imT) = carrier W"
by (simp add: surj im_def) 

(* For finite-dimensional V: if T is injective on `carrier V` and maps it
   onto `carrier W`, then V and W have the same dimension.
   Proof outline:
     dim0   - the kernel space has dimension 0 (inj_imp_dim_ker0);
     imT_W  - the image is the whole carrier of W (surj_imp_imT_carrier);
     rnt    - rank_nullity relates dim of image space, dim of kernel space
              and V.dim;
   the final calculation chains these three facts. *)
lemma dim_eq:
assumes fin_dim_V: "V.fin_dim"
and i: "inj_on T (carrier V)" and surj: "T` (carrier V) = carrier W"
shows "V.dim = W.dim"
proof -
  have dim0: "vectorspace.dim K (V.vs kerT) = 0" 
    by (rule inj_imp_dim_ker0[OF i])
  have imT_W: "(imT) = carrier W"
    by (rule surj_imp_imT_carrier[OF surj])
  have rnt: "vectorspace.dim K (W.vs imT) + vectorspace.dim K (V.vs kerT) = V.dim"
    by (rule rank_nullity[OF fin_dim_V])
  hence "V.dim = vectorspace.dim K (W.vs imT)" using dim0 by auto
  also have "...  = W.dim" using imT_W by auto
  finally show ?thesis using fin_dim_V by auto
qed       


(* Linear combinations commute with an injective linear map: for a finite set
   A in `carrier V` and coefficients `a` defined on the image T ` A, the
   W-linear combination of T ` A with coefficients `a` equals T applied to the
   V-linear combination of A with the coefficients written `a  T`.
   NOTE(review): the operator in `a  T` is not visible in this copy; the
   later use of comp_def suggests function composition -- confirm.
   Proof: induction on the finite set A.  The insert case is a calculation:
     1. rewrite the image of `insert x A`;
     2. split off the T x summand with W.lincomb_insert2 (injectivity of T is
        what guarantees T x is not already in T ` A);
     3. apply the induction hypothesis to the remaining sum;
     4. pull the scalar multiplication and then the addition inside T;
     5. fold the V-side sum back with V.lincomb_insert2. *)
lemma lincomb_linear_image:
assumes inj_T: "inj_on T (carrier V)"
assumes A_in_V: "A  carrier V" and a: "a  (T`A)  carrier K"
assumes f: "finite A"
shows "W.module.lincomb a (T`A) = T (V.module.lincomb (a  T) A)"
using f using A_in_V a
proof (induct A)
  case empty thus ?case by auto
next
  case (insert x A)
  have T_insert_rw: "T ` insert x A = insert (T x) (T` A)" by simp
  have "W.module.lincomb a (T ` insert x A) = W.module.lincomb a (insert (T x) (T` A))" 
    unfolding T_insert_rw ..
  also have "... =  a (T x) W (T x) W W.module.lincomb a (T` A)"
  proof (rule W.lincomb_insert2)
    show "finite (T ` A)" by (simp add: insert.hyps(1))
    show "T ` A  carrier W" using insert.prems(1) by auto
    show "a  insert (T x) (T ` A)  carrier K" 
      using insert.prems(2) by blast
    show "T x  T ` A" 
      by (meson inj_T inj_on_image_mem_iff insert.hyps(2) insert.prems(1) insert_subset)
    show "T x  carrier W" using insert.prems(1) by blast
  qed
  (* induction hypothesis on the tail sum *)
  also have "... = a (T x) W (T x) W (T (V.module.lincomb (a  T) A))"
    using insert.hyps(3) insert.prems(1) insert.prems(2) by fastforce 
  (* scalar multiplication moved inside T *)
  also have "... = T (a (T x) V x) W (T (V.module.lincomb (a  T) A))"
    using insert.prems(1) insert.prems(2) by auto
  (* addition moved inside T *)
  also have "... = T ((a (T x) V x) V (V.module.lincomb (a  T) A))"
  proof (rule T_add[symmetric])
    show "a (T x) V x  carrier V" using insert.prems(1) insert.prems(2) by auto
    show "V.module.lincomb (a  T) A  carrier V"
    proof (rule V.module.lincomb_closed)
      show "A  carrier V" using insert.prems(1) by blast
      show "a  T  A  carrier K" using coeff_in_ring insert.prems(2) by auto
    qed
  qed
  also have "... = T (V.module.lincomb (a  T) (insert x A))"
  proof (rule arg_cong[of _ _ T])
    have "a  T  insert x A  carrier K"
      using comp_def insert.prems(2) by auto
    then show "a (T x) V x V V.module.lincomb (a  T) A 
      = V.module.lincomb (a  T) (insert x A)"
      using V.lincomb_insert2 insert.hyps(1) insert.hyps(2) insert.prems(1) by force
  qed
  finally show ?case .
qed
   


(* If V is finite-dimensional and T maps `carrier V` onto `carrier W`, then W
   is finite-dimensional.  This is the second conclusion of rank_nullity_main.
   The conclusion is additionally bound to the name `image_fin_dim`. *)
lemma surj_fin_dim:  
  assumes fd: "V.fin_dim" and surj: "T` (carrier V) = carrier W"
  shows image_fin_dim: "W.fin_dim"
    using rank_nullity_main(2)[OF fd surj] .

(* The image of a basis B of a finite-dimensional V under a map T that is
   injective on `carrier V` and maps it onto `carrier W` is a basis of W.
   Proof: W.dim_li_is_basis reduces the claim to five obligations:
     - W.fin_dim                      (surj_fin_dim)
     - finite (T ` B)                 (image of a finite set, V.fin)
     - T ` B lies in `carrier W`
     - a comparison of W.dim with card (T ` B)
       (via dim_eq and card_image; NOTE(review): the relation symbol is not
       visible in this copy, presumably less-or-equal -- confirm)
     - linear independence of T ` B   (W.module.finite_lin_indpt2)
   For independence, a vanishing combination of T ` B is pulled back through
   T with lincomb_linear_image; injectivity gives a vanishing combination of
   B, and independence of B forces all coefficients to zero. *)
lemma linear_inj_image_is_basis:
assumes inj_T: "inj_on T (carrier V)" and surj: "T` (carrier V) = carrier W"
and basis_B: "V.basis B"
and fin_dim_V: "V.fin_dim"
shows "W.basis (T`B)"
proof (rule W.dim_li_is_basis)
  have lm: "linear_map K V W T" by intro_locales
  have inj_TB: "inj_on T B"
    by (meson basis_B inj_T subset_inj_on V.basis_def)
  show "W.fin_dim" by (rule surj_fin_dim[OF fin_dim_V surj])  
  show "finite (T ` B)"
  proof (rule finite_imageI, rule V.fin[OF fin_dim_V])
    show "V.module.lin_indpt B" using basis_B unfolding V.basis_def by auto
    show "B  carrier V" using basis_B unfolding V.basis_def by auto
  qed
  show "T ` B  carrier W" using basis_B unfolding V.basis_def by auto
  show "W.dim  card (T ` B)"
  proof -
    have d: "V.dim = W.dim" by (rule dim_eq[OF fin_dim_V inj_T surj])
    have "card (T` B) = card B" by (simp add: card_image inj_TB)
    also have "... = V.dim" using basis_B fin_dim_V V.basis_def V.dim_basis V.fin by auto
    finally show ?thesis using d by simp
  qed
  show "W.module.lin_indpt (T ` B)"
  proof (rule W.module.finite_lin_indpt2)
     show fin_TB: "finite (T ` B)" by fact
     show TB_W: "T ` B  carrier W" by fact
     fix a assume a: "a  T ` B  carrier K" and lc_a: "W.module.lincomb a (T ` B) = 𝟬W"
     show "vT ` B. a v = 𝟬K" 
     proof (rule ballI)
      fix v assume v: "v  T ` B"
      (* pull the combination back through T *)
      have "W.module.lincomb a (T ` B) = T (V.module.lincomb (a  T) B)"
      proof (rule lincomb_linear_image[OF inj_T])
        show "B  carrier V" using V.vectorspace_axioms basis_B vectorspace.basis_def by blast
        show "a  T ` B  carrier K" by (simp add: a)
        show "finite B" using fin_TB finite_image_iff inj_TB by blast
      qed
      hence T_lincomb: "T (V.module.lincomb (a  T) B) = 𝟬W" using lc_a by simp
      (* T maps zero to zero (f0_is_0) and is injective, so the preimage is zero *)
      have lincomb_0: "V.module.lincomb (a  T) B = 𝟬V"
      proof -
        have "a  T  B  carrier K"
          using a by auto
        then show ?thesis
          by (metis V.module.M.zero_closed V.module.lincomb_closed 
            T_lincomb basis_B f0_is_0 inj_T inj_onD  V.basis_def)
      qed 
      (* independence of B forces the pulled-back coefficients to vanish *)
      have "(a  T)  B  {𝟬K}" 
      proof (rule V.not_lindepD[OF _ _ _ _ lincomb_0])
        show "V.module.lin_indpt B" using V.basis_def basis_B by blast
        show "finite B" using fin_TB finite_image_iff inj_TB by auto
        show "B  B" by auto
        show "a  T  B  carrier K" using a by auto
      qed
      thus "a v = 𝟬K" using v by auto
    qed
  qed
qed

end

(* Introduction rule for dimension 1: if the singleton {v} is a generating
   set, v is in `carrier V`, and v satisfies the second assumption relating
   it to the zero vector, then dim = 1.
   NOTE(review): the relation in `v  𝟬V` is not visible in this copy;
   presumably it states that v is different from zero -- confirm.
   Proof: metis shows that {v} is a basis; dim_basis then gives the result. *)
lemma (in vectorspace) dim1I:
assumes "gen_set {v}"
assumes "v  𝟬V" "v  carrier V"
shows "dim = 1"
proof -
  have "basis {v}" by (metis assms(1) assms(2) assms(3) basis_def empty_iff empty_subsetI
   finite.emptyI finite_lin_indpt2 insert_iff insert_subset insert_union lin_dep_iff_in_span
   span_empty)
  then show ?thesis using dim_basis by force
qed

(* Introduction rule for dimension 0: if the singleton of the zero vector is
   a generating set, then dim = 0.
   Proof: the empty set is shown to be a basis (unfolding basis_def), and
   dim_basis yields the dimension. *)
lemma (in vectorspace) dim0I:
assumes "gen_set {𝟬V}"
shows "dim = 0"
proof -
  have "basis {}" unfolding basis_def using already_in_span assms finite_lin_indpt2 span_zero by auto
  then show ?thesis using dim_basis by force
qed

(* Bound on the dimension when a single vector of `carrier V` generates the
   space.  NOTE(review): the relation in `dim  1` is not visible in this
   copy; the lemma name and the use of gen_ge_dim suggest "at most 1" --
   confirm.  Proved by metis from gen_ge_dim and cardinality facts about
   singleton sets. *)
lemma (in vectorspace) dim_le1I:
assumes "gen_set {v}"
assumes "v  carrier V"
shows "dim  1"
by (metis One_nat_def assms(1) assms(2) bot.extremum card.empty card.insert empty_iff finite.intros(1)
finite.intros(2) insert_subset vectorspace.gen_ge_dim vectorspace_axioms)

(* find_indices x xs: the indices i taken from [0..<length xs] for which
   xs!i = x, as a list in the order produced by the comprehension.
   (The generator arrow of the list comprehension is not visible in this
   copy of the text.) *)
definition find_indices where "find_indices x xs  [i  [0..<length xs]. xs!i = x]"

(* No index of the empty list matches: find_indices returns [] on []. *)
lemma find_indices_Nil [simp]:
  "find_indices x [] = []"
  unfolding find_indices_def by simp

(* Recursion equation for a non-empty list: the indices found in the tail are
   shifted by Suc, and index 0 is prepended exactly when x equals the head.
   Proof: unfold the definition, peel the first element off the index range
   with upt_conv_Cons, then express the remaining range as a Suc-image
   (map_Suc_upt) and push the filter through the map. *)
lemma find_indices_Cons:
  "find_indices x (y#ys) = (if x = y then Cons 0 else id) (map Suc (find_indices x ys))"
apply (unfold find_indices_def length_Cons, subst upt_conv_Cons, simp)
apply (fold map_Suc_upt, auto simp: filter_map o_def) done

(* Appending one element y: the indices found in ys are kept, and the new
   last index `length ys` is appended exactly when x = y.  Declared [simp]. *)
lemma find_indices_snoc [simp]:
  "find_indices x (ys@[y]) = find_indices x ys @ (if x = y then [length ys] else [])"
  by (unfold find_indices_def, auto intro!: filter_cong simp: nth_append)

(* Characterisation of the elements of find_indices x xs in terms of
   `i < length xs` and `xs!i = x`.  (The logical connectives of the statement
   are not visible in this copy of the text.)  Declared [simp]. *)
lemma mem_set_find_indices [simp]: "i  set (find_indices x xs)  i < length xs  xs!i = x"
  unfolding find_indices_def by auto

(* The list returned by find_indices never contains an index twice. *)
lemma distinct_find_indices: "distinct (find_indices x xs)"
  by (simp add: find_indices_def)

context abelian_monoid begin

(* sumlist xs: the sum of the list xs, computed as a right fold of the
   monoid addition starting from the zero element. *)
definition sumlist
  where "sumlist xs  foldr (⊕) xs 𝟬"
  (* Design note: `fold` is not a good choice because it reverses the list,
     although the most general locale for monoids with ⊕ is already Abelian
     in Isabelle 2016-1.  `foldl` would be OK, but it will not simplify on
     Cons. *)

(* Defining equations of sumlist, both declared [simp]:
     sumlist_Cons - the sum of x#xs is x combined with sumlist xs
                    (the infix operator is not visible in this copy;
                    by sumlist_def it is the monoid addition);
     sumlist_Nil  - the sum of the empty list is the zero element. *)
lemma [simp]:
  shows sumlist_Cons: "sumlist (x#xs) = x  sumlist xs"
    and sumlist_Nil: "sumlist [] = 𝟬"
  by (simp_all add: sumlist_def)

(* Closure: if all list elements are in `carrier G`, so is their sum.
   Proved by induction on the list.  Declared [simp]. *)
lemma sumlist_carrier [simp]:
  assumes "set xs  carrier G" shows "sumlist xs  carrier G"
  using assms by (induct xs, auto)

(* A list consisting only of zeros sums to zero.
   Proof: induction on xs; in the Cons case the head is zero and the tail
   again consists only of zeros, so the induction hypothesis applies. *)
lemma sumlist_neutral:
  assumes "set xs  {𝟬}" shows "sumlist xs = 𝟬"
proof (insert assms, induct xs)
  case (Cons x xs)
  then have "x = 𝟬" and "set xs  {𝟬}" by auto
  with Cons.hyps show ?case by auto
qed simp

(* The sum of a concatenation is the combination of the two sums, for lists
   whose elements lie in `carrier G`.
   Proof: induction on xs (ys arbitrary); the Cons case uses the induction
   hypothesis and associativity (a_assoc, applied right-to-left). *)
lemma sumlist_append:
  assumes "set xs  carrier G" and "set ys  carrier G"
  shows "sumlist (xs @ ys) = sumlist xs  sumlist ys"
proof (insert assms, induct xs arbitrary: ys)
  case (Cons x xs)
  have "sumlist (xs @ ys) = sumlist xs  sumlist ys"
    using Cons.prems by (auto intro: Cons.hyps)
  with Cons.prems show ?case by (auto intro!: a_assoc[symmetric])
qed auto

(* Appending a single element x adds x on the right of the sum.
   Special case of sumlist_append, applied by substitution. *)
lemma sumlist_snoc:
  assumes "set xs  carrier G" and "x  carrier G"
  shows "sumlist (xs @ [x]) = sumlist xs  x"
  by (subst sumlist_append, insert assms, auto)

(* For a distinct list with elements in `carrier G`, sumlist agrees with the
   set-indexed sum over `set xs`.  (The sum binder on the right-hand side is
   not visible in this copy; the proof uses finsum_insert.)
   Proof: induction on xs, with finsum_insert applied right-to-left. *)
lemma sumlist_as_finsum:
  assumes "set xs  carrier G" and "distinct xs" shows "sumlist xs = (xset xs. x)"
  using assms by (induct xs, auto intro:finsum_insert[symmetric])

(* Mapped variant of sumlist_as_finsum: for a distinct list xs and f sending
   its elements into `carrier G`, summing `map f xs` agrees with the
   set-indexed sum of f over `set xs`.  Proved by induction on xs. *)
lemma sumlist_map_as_finsum:
  assumes "f : set xs  carrier G" and "distinct xs"
  shows "sumlist (map f xs) = (x  set xs. f x)"
  using assms by (induct xs, auto)

(* summset M: the sum of the multiset M, defined as a multiset fold of the
   monoid addition starting from the zero element. *)
definition summset where "summset M  fold_mset (⊕) 𝟬 M"

(* The empty multiset sums to the zero element.  Declared [simp]. *)
lemma summset_empty [simp]: "summset {#} = 𝟬" unfolding summset_def by simp

(* Closure of the multiset fold: starting from a in `carrier G` and folding
   the addition over a multiset whose elements are in `carrier G` stays in
   `carrier G`.
   Proof: multiset induction (a arbitrary).  In the add case the fold is
   unrolled with comp_fun_commute_on.fold_mset_add_mset, instantiated for
   `carrier G`; the commutation obligation is discharged with a_lcomm. *)
lemma fold_mset_add_carrier: "a  carrier G  set_mset M  carrier G  fold_mset (⊕) a M  carrier G" 
proof (induct M arbitrary: a)
  case (add x M)
  thus ?case by 
    (subst comp_fun_commute_on.fold_mset_add_mset[of _ "carrier G"], unfold_locales, auto simp: a_lcomm)
qed simp

(* Closure: the sum of a multiset with elements in `carrier G` is in
   `carrier G`.  Instance of fold_mset_add_carrier; declared [intro]. *)
lemma summset_carrier[intro]: "set_mset M  carrier G  summset M  carrier G" 
  unfolding summset_def by (rule fold_mset_add_carrier, auto)  

(* Adding one element a to the multiset adds a to the sum, provided a and
   all elements of M are in `carrier G`.  Declared [simp].
   Proof: unfold summset_def, then apply
   comp_fun_commute_on.fold_mset_add_mset; its locale obligations are closed
   with a_lcomm. *)
lemma summset_add_mset[simp]:
  assumes a: "a  carrier G" and MG: "set_mset M  carrier G"
  shows "summset (add_mset a M) = a  summset M"
  using assms 
  by (auto simp add: summset_def)
   (rule comp_fun_commute_on.fold_mset_add_mset, unfold_locales, auto simp add: a_lcomm)    
 
(* The list sum equals the multiset sum of `mset xs` when all elements are in
   `carrier G`.  Proved by induction on xs. *)
lemma sumlist_as_summset:
  assumes "set xs  carrier G" shows "sumlist xs = summset (mset xs)"
  by (insert assms, induct xs, auto)

(* Reversing a list does not change its sum (elements in `carrier G`).
   Proved by simp after rewriting with sumlist_as_summset. *)
lemma sumlist_rev:
  assumes "set xs  carrier G"
  shows "sumlist (rev xs) = sumlist xs"
  using assms by (simp add: sumlist_as_summset)

(* sumlist can also be computed with `fold` instead of `foldr` when all
   elements are in `carrier G`.
   Proof: fold sumlist_rev to replace the goal's `sumlist xs` by the sum of
   the reversed list, then simp with sumlist_def and foldr_conv_fold. *)
lemma sumlist_as_fold:
  assumes "set xs  carrier G"
  shows "sumlist xs = fold (⊕) xs 𝟬"
  by (fold sumlist_rev[OF assms], simp add: sumlist_def foldr_conv_fold)

end

context Module.module begin

(* lincomb_list c vs: linear combination of the list vs with position-indexed
   coefficients: for every index i in [0..<length vs] the coefficient c i is
   combined with the i-th vector vs ! i, and the results are added up with
   sumlist.  (The scalar-multiplication symbol is not visible in this copy;
   smult_closed is what proves closure of each summand below.) *)
definition lincomb_list
where "lincomb_list c vs = sumlist (map (λi. c i M vs ! i) [0..<length vs])"

(* Closure: if the vectors are in `carrier M` and c sends every valid index
   to `carrier R`, then lincomb_list c vs is in `carrier M`.
   Proof: unfold the definition, apply sumlist_carrier, and close each
   summand with smult_closed. *)
lemma lincomb_list_carrier:
  assumes "set vs  carrier M" and "c : {0..<length vs}  carrier R"
  shows "lincomb_list c vs  carrier M"
  by (insert assms, unfold lincomb_list_def, intro sumlist_carrier, auto intro!: smult_closed)

(* The list-indexed linear combination of the empty list is the zero element
   (shown here as `𝟬M`).  Declared [simp]. *)
lemma lincomb_list_Nil [simp]: "lincomb_list c [] = 𝟬M"
  unfolding lincomb_list_def by simp

(* Recursion equation: the combination of v#vs is the term built from c 0 and
   v, combined with the combination of vs under the shifted coefficient
   function `c o Suc`.  Declared [simp].
   Proof: unfold the definition, peel index 0 off the range with
   upt_conv_Cons, rewrite the remaining range as a Suc-image (map_Suc_upt),
   and simplify with o_def. *)
lemma lincomb_list_Cons [simp]:
  "lincomb_list c (v#vs) = c 0 M v M lincomb_list (c o Suc) vs"
  by (unfold lincomb_list_def length_Cons, subst upt_conv_Cons, simp, fold map_Suc_upt, simp add: o_def)

(* If every individual term (coefficient c i with vector vs ! i, for
   i < length vs) is the zero element, the whole combination is zero.
   Proof: induction on vs with c arbitrary.  In the Cons case the assumption
   at index 0 handles the head term, and the assumption at indices `Suc _`
   feeds the induction hypothesis for the shifted coefficient function. *)
lemma lincomb_list_eq_0:
  assumes "i. i < length vs  c i M vs ! i = 𝟬M"
  shows "lincomb_list c vs = 𝟬M"
proof (insert assms, induct vs arbitrary:c)
  case (Cons v vs)
  from Cons.prems[of 0] have [simp]: "c 0 M v = 𝟬M" by auto
  from Cons.prems[of "Suc _"] Cons.hyps have "lincomb_list (c  Suc) vs = 𝟬M" by auto
  then show ?case by (simp add: o_def)
qed simp

(* mk_coeff vs c v: turns a position-indexed coefficient function c into a
   vector-indexed one by adding up (with R.sumlist) the coefficients c i at
   all positions i returned by `find_indices v vs`. *)
definition mk_coeff where "mk_coeff vs c v  R.sumlist (map c (find_indices v vs))"

(* Closure: if c sends every valid index of vs into `carrier R`, then
   mk_coeff vs c w is in `carrier R` for every w.
   Proof: unfold mk_coeff_def and find_indices_def and apply
   R.sumlist_carrier. *)
lemma mk_coeff_carrier:
  assumes "c : {0..<length vs}  carrier R" shows "mk_coeff vs c w  carrier R"
  by (insert assms, auto simp: mk_coeff_def find_indices_def intro!:R.sumlist_carrier elim!:funcset_mem)

(* Recursion equation for mk_coeff on a non-empty list: the coefficient of w
   is (c 0 if w = v, zero otherwise) combined with the coefficient of w in
   the tail under the shifted function `c o Suc`.
   Proof: the shifted function is a valid coefficient function for vs, so
   mk_coeff_carrier applies; the rest follows from find_indices_Cons. *)
lemma mk_coeff_Cons:
  assumes "c : {0..<length (v#vs)}  carrier R"
  shows "mk_coeff (v#vs) c = (λw. (if w = v then c 0 else 𝟬)  mk_coeff vs (c o Suc) w)"
proof-
  from assms have "c o Suc : {0..<length vs}  carrier R" by auto
  from mk_coeff_carrier[OF this] assms
  show ?thesis by (auto simp add: mk_coeff_def find_indices_Cons)
qed

(* The coefficient produced by mk_coeff is zero under the stated assumption
   on v and `set vs`.  Declared [simp].
   NOTE(review): the relation in the assumption `v  set vs` is not visible in
   this copy; since the proof derives `find_indices v vs = []`, it presumably
   says that v does not occur in vs -- confirm. *)
lemma mk_coeff_0[simp]:
  assumes "v  set vs"
  shows "mk_coeff vs c v = 𝟬"
proof -
  have "(find_indices v vs) = []" using assms unfolding find_indices_def
    by (simp add: in_set_conv_nth)
  thus ?thesis  unfolding mk_coeff_def by auto
qed  

(* Bridge between list-indexed and set-indexed linear combinations:
   lincomb_list c vs equals lincomb over `set vs` with the coefficient
   function `mk_coeff vs c`, provided the vectors are in `carrier M` and c
   sends every valid index into `carrier R`.

   Proof: induction on vs with c arbitrary.  In the Cons case the goal is
   rewritten with mk_coeff_Cons and lincomb_list_Cons, the induction
   hypothesis is substituted, and the proof splits on the proposition
   `v  set vs` (NOTE(review): its relation symbol is not visible in this
   copy; the use of insert_absorb and of mk_coeff_0 in the False branch
   suggests membership of v in `set vs` -- confirm):
     - False: `insert v (set vs)` is a genuinely larger set.  The v-summand
       is split off with finsum_insert (finsum_rw), the tail coefficient of
       v vanishes (mk_0), and on `set vs` the if-expression always takes its
       zero branch (finsum_rw2).
     - True: `insert v (set vs)` collapses to `set vs`.  The v-summand is
       isolated with M.add.finprod_split, the sum of the two coefficients is
       distributed over the scalar multiplication (rw, smult_l_distr), the
       terms are re-associated (M.a_assoc), and the remaining terms are
       reassembled into the sum over `set vs` (rw2). *)
lemma lincomb_list_as_lincomb:
  assumes vs_M: "set vs  carrier M" and c: "c : {0..<length vs}  carrier R"
  shows "lincomb_list c vs = lincomb (mk_coeff vs c) (set vs)"
proof (insert assms, induct vs arbitrary: c)
  case (Cons v vs)
  (* closure facts reused throughout both branches *)
  have mk_coeff_Suc_closed: "mk_coeff vs (c  Suc) a  carrier R" for a
    apply (rule mk_coeff_carrier)
    using Cons.prems unfolding Pi_def by auto
  have x_in: "x  carrier M" if x: "x set vs" for x using Cons.prems x by auto
  show ?case apply (unfold mk_coeff_Cons[OF Cons.prems(2)] lincomb_list_Cons)
    apply (subst Cons) using Cons apply (force, force)
  proof (cases "v  set vs", auto simp:insert_absorb)
    case False
    let ?f = "(λva. ((if va = v then c 0 else 𝟬)  mk_coeff vs (c  Suc) va) M va)"
    have mk_0: "mk_coeff vs (c  Suc) v = 𝟬" using False by auto
    have [simp]: "(c 0  𝟬) = c 0"
      using Cons.prems(2) by force
    (* split the v-summand off the sum over insert v (set vs) *)
    have finsum_rw: "(Mvainsert v (set vs). ?f va) = (?f v) M (Mva(set vs). ?f va)"
    proof (rule finsum_insert, auto simp add: False, rule smult_closed, rule R.a_closed)
      fix x
      show "mk_coeff vs (c  Suc) x  carrier R" 
        using mk_coeff_Suc_closed by auto
      show "c 0 M v  carrier M"
      proof (rule smult_closed)
        show "c 0  carrier R"
          using Cons.prems(2) by fastforce
        show "v  carrier M"
          using Cons.prems(1) by auto
      qed
      show "𝟬  carrier R"
        by simp
      assume x: "x  set vs" show "x  carrier M"
        using Cons.prems(1) x by auto
    qed
    (* on set vs the if-expression contributes only zero *)
    have finsum_rw2: 
      "(Mva(set vs). ?f va) = (Mvaset vs. (mk_coeff vs (c  Suc) va) M va)"
    proof (rule finsum_cong2, auto simp add: False)
      fix i assume i: "i  set vs"
      have "c  Suc  {0..<length vs}  carrier R" using Cons.prems by auto
      then have [simp]: "mk_coeff vs (c  Suc) i  carrier R" 
        using mk_coeff_Suc_closed by auto
      have "𝟬  mk_coeff vs (c  Suc) i = mk_coeff vs (c  Suc) i" by (rule R.l_zero, simp)
      then show "(𝟬  mk_coeff vs (c  Suc) i) M i = mk_coeff vs (c  Suc) i M i" 
        by auto
      show "(𝟬  mk_coeff vs (c  Suc) i) M i  carrier M"
        using Cons.prems(1) i by auto
    qed
    show "c 0 M v M lincomb (mk_coeff vs (c  Suc)) (set vs) =
    lincomb (λa. (if a = v then c 0 else 𝟬)  mk_coeff vs (c  Suc) a) (insert v (set vs))" 
      unfolding lincomb_def
      unfolding finsum_rw mk_0 
      unfolding finsum_rw2 by auto
  next
    case True
    let ?f = "λva. ((if va = v then c 0 else 𝟬)  mk_coeff vs (c  Suc) va) M va"
    (* distribute the sum of the two coefficients over the scalar multiplication *)
    have rw: "(c 0  mk_coeff vs (c  Suc) v) M v 
      = (c 0 M v) M (mk_coeff vs (c  Suc) v) M v"      
      using Cons.prems(1) Cons.prems(2) atLeast0_lessThan_Suc_eq_insert_0 
      using mk_coeff_Suc_closed smult_l_distr by auto
    (* reassemble the tail term for v with the rest into the sum over set vs *)
    have rw2: "((mk_coeff vs (c  Suc) v) M v) 
      M (Mva(set vs - {v}). ?f va) = (Mvset vs. mk_coeff vs (c  Suc) v M v)"
    proof -
      have "(Mva(set vs - {v}). ?f va) = (Mvset vs - {v}. mk_coeff vs (c  Suc) v M v)"
        by (rule finsum_cong2, unfold Pi_def, auto simp add: mk_coeff_Suc_closed x_in)
      moreover have "(Mvset vs. mk_coeff vs (c  Suc) v M v) = ((mk_coeff vs (c  Suc) v) M v) 
        M (Mvset vs - {v}. mk_coeff vs (c  Suc) v M v)"
        by (rule M.add.finprod_split, auto simp add: mk_coeff_Suc_closed True x_in)
      ultimately show ?thesis by auto
    qed
    have "lincomb (λa. (if a = v then c 0 else 𝟬)  mk_coeff vs (c  Suc) a) (set vs) 
      = (Mvaset vs. ?f va)" unfolding lincomb_def ..
    (* isolate the v-summand *)
    also have "... = ?f v M (Mva(set vs - {v}). ?f va)"
    proof (rule M.add.finprod_split)
      have c0_mkcoeff_in: "c 0  mk_coeff vs (c  Suc) v  carrier R" 
      proof (rule R.a_closed)
        show "c 0  carrier R " using Cons.prems by auto
        show "mk_coeff vs (c  Suc) v  carrier R"
          using mk_coeff_Suc_closed by auto
    qed
    moreover have "(𝟬  mk_coeff vs (c  Suc) va) M va  carrier M"
      if va: "va  carrier M" for va 
      by (rule smult_closed[OF _ va], rule R.a_closed, auto simp add: mk_coeff_Suc_closed)
    ultimately show "?f ` set vs  carrier M" using Cons.prems(1) by auto        
      show "finite (set vs)" by simp
      show "v  set vs" using True by simp
    qed
    also have "... = (c 0  mk_coeff vs (c  Suc) v) M v 
      M (Mva(set vs - {v}). ?f va)" by auto
    also have "... = ((c 0 M v) M (mk_coeff vs (c  Suc) v) M v) 
      M (Mva(set vs - {v}). ?f va)" unfolding rw by simp
    (* re-associate so that rw2 becomes applicable *)
    also have "... = (c 0 M v) M (((mk_coeff vs (c  Suc) v) M v) 
      M (Mva(set vs - {v}). ?f va))"
    proof (rule M.a_assoc)
      show "c 0 M v  carrier M" 
        using Cons.prems(1) Cons.prems(2) by auto
      show "mk_coeff vs (c  Suc) v M v  carrier M"
        using Cons.prems(1) mk_coeff_Suc_closed by auto
      show "(Mvaset vs - {v}. ((if va = v then c 0 else 𝟬) 
         mk_coeff vs (c  Suc) va) M va)  carrier M"
        by (rule M.add.finprod_closed) (auto simp add: mk_coeff_Suc_closed x_in)
    qed
    also have "... = c 0 M v M (Mvset vs. mk_coeff vs (c  Suc) v M v)"
      unfolding rw2 ..
    also have "... = c 0 M v M lincomb (mk_coeff vs (c  Suc)) (set vs)" 
      unfolding lincomb_def ..
    finally show "c 0 M v M lincomb (mk_coeff vs (c  Suc)) (set vs) 
      = lincomb (λa. (if a = v then c 0 else 𝟬)  mk_coeff vs (c  Suc) a) (set vs)" ..         
  qed
qed simp

(* span_list vs: the set of all values `lincomb_list c vs`, where c ranges
   over the coefficient functions sending every valid index of vs into
   `carrier R`. *)
definition "span_list vs  {lincomb_list c vs | c. c : {0..<length vs}  carrier R}"

(* Introduction rule for span_list: a vector that can be written as
   `lincomb_list c vs` for a valid coefficient function c lies in
   `span_list vs`. *)
lemma in_span_listI:
  assumes "c : {0..<length vs}  carrier R" and "v = lincomb_list c vs"
  shows "v  span_list vs"
  unfolding span_list_def using assms by auto

(* Elimination rule for span_list: from membership of v in `span_list vs`
   one obtains a valid coefficient function c with v = lincomb_list c vs. *)
lemma in_span_listE:
  assumes "v  span_list vs"
      and "c. c : {0..<length vs}  carrier R  v = lincomb_list c vs  thesis"
  shows thesis
  using assms by (auto simp: span_list_def)

(* Variant of lincomb_insert obtained by unfolding insert_union right-to-left
   in that theorem. *)
lemmas lincomb_insert2 = lincomb_insert[unfolded insert_union[symmetric]]

(* Module version of the zero-coefficient lemma: a linear combination whose
   coefficient function only takes the value `zero R` on U equals `zero M`.
   Proof: infinite_finite_induct on U; the insert case shows that `a` is a
   valid coefficient function on `insert u U` (via zero_closed) and rewrites
   with lincomb_insert2; the remaining case is closed with lincomb_def. *)
lemma lincomb_zero:
  assumes U: "U  carrier M" and a: "a : U  {zero R}"
  shows "lincomb a U = zero M"
  using U a
proof (induct U rule: infinite_finite_induct)
  case empty show ?case unfolding lincomb_def by auto next
  case (insert u U)
    hence "a  insert u U  carrier R" using zero_closed by force
    thus ?case using insert by (subst lincomb_insert2; auto)
qed (auto simp: lincomb_def)

end

(* Hide the short name of Multiset.mult (open mode: the fully qualified name
   stays accessible). *)
hide_const (open) Multiset.mult
end

Theory VS_Connect

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
(* with contributions from Alexander Bentkamp, Universität des Saarlandes *)

section ‹Matrices as Vector Spaces›

text ‹This theory connects the Matrix theory with the VectorSpace theory of
  Holden Lee. As a consequence, notions such as span, basis, and linear dependence
  are available for vectors and matrices of the Matrix theory.›

theory VS_Connect
imports 
  Matrix
  Missing_VectorSpace
  Determinant
begin

(* Hide the short names of several imported constants and one fact (open
   mode: qualified names stay accessible).
   NOTE(review): presumably done to avoid clashes with the names used by the
   algebra and vector-space locales in this theory -- confirm. *)
hide_const (open) Multiset.mult
hide_const (open) Polynomial.smult
hide_const (open) Modules.module
hide_const (open) subspace
hide_fact (open) subspace_def

(* Named collection for the rewrite rules that translate the components of
   `class_ring` into the corresponding type-class operations; filled by the
   [class_ring_simps] attributes on the interpretations below. *)
named_theorems class_ring_simps

(* class_ring: the ring record built from type-class operations: carrier
   UNIV, and mult / one / zero / add taken from the class times, 1, 0 and
   plus.  (The record brackets are not visible in this copy of the text.) *)
abbreviation class_ring :: "'a :: {times,plus,one,zero} ring" where
  "class_ring   carrier = UNIV, mult = (*), one = 1, zero = 0, add = (+) "

(* Interpretation of the locale `semiring` for class_ring at any type of
   class semiring_1.  The rewrite clauses (each stored in class_ring_simps)
   replace carrier, mult, add, one, zero, pow and finsum of class_ring by
   UNIV, the class operations, (^) and sum.
   Proof obligations:
     - semiring:  unfold_locales and field_simps;
     - pow:       induction on the exponent, using power_commutes;
     - finsum:    infinite_finite_induct on the index set;
     - the record-field equations are closed by the final auto. *)
interpretation class_semiring: semiring "class_ring :: 'a :: semiring_1 ring"
  rewrites [class_ring_simps]: "carrier class_ring = UNIV"
    and [class_ring_simps]: "mult class_ring = (*)"
    and [class_ring_simps]: "add class_ring = (+)"
    and [class_ring_simps]: "one class_ring = 1"
    and [class_ring_simps]: "zero class_ring = 0"
    and [class_ring_simps]: "pow (class_ring :: 'a ring) = (^)"
    and [class_ring_simps]: "finsum (class_ring :: 'a ring) = sum"
proof -
  let ?r = "class_ring :: 'a ring"
  show "semiring ?r"
    by (unfold_locales, auto simp: field_simps)
  then interpret semiring ?r .
  {
    fix x y
    have "x [^]?r y = x ^ y"
      by (induct y, auto simp: power_commutes)
  }
  thus "([^]?r) = (^)" by (intro ext)
  {
    fix f and A :: "'b set"
    have "finsum ?r f A = sum f A"
      by (induct A rule: infinite_finite_induct, auto)
  }
  thus "finsum ?r = sum" by (intro ext)
qed auto 

(* Interpretation of the locale `ring` for class_ring at any type of class
   ring_1.  Repeats the rewrite clauses of class_semiring and adds two new
   ones, stored in class_ring_simps: a_inv becomes uminus and a_minus becomes
   minus.
   Proof obligations:
     - finsum / pow: from class_ring_simps (established by class_semiring);
     - ring:  unfold_locales, with the local simp fact that every x has an
              additive inverse (witness -x);
     - a_inv: unfold a_inv_def and m_inv_def and use uniqueness of the
              described element (the1_equality, minus_unique);
     - a_minus: via a_minus_def and the a_inv equation. *)
interpretation class_ring: ring "class_ring :: 'a :: ring_1 ring"
  rewrites "carrier class_ring = UNIV"
    and "mult class_ring = (*)"
    and "add class_ring = (+)"
    and "one class_ring = 1"
    and "zero class_ring = 0"
    and [class_ring_simps]: "a_inv (class_ring :: 'a ring) = uminus"
    and [class_ring_simps]: "a_minus (class_ring :: 'a ring) = minus"
    and "pow (class_ring :: 'a ring) = (^)"
    and "finsum (class_ring :: 'a ring) = sum"
proof -
  let ?r = "class_ring :: 'a ring"
  interpret semiring ?r ..
  show "finsum ?r = sum" "pow ?r = (^)" by (simp_all add: class_ring_simps)
  {
    fix x :: 'a
    have "y. x + y = 0" by (rule exI[of _ "-x"], auto)
  } note [simp] = this
  show "ring ?r"
    by (unfold_locales, auto simp: field_simps Units_def)
  then interpret ring ?r .
  {
    fix x :: 'a
    have "?r x = - x" unfolding a_inv_def m_inv_def
      by (rule the1_equality, rule ex1I[of _ "- x"], auto simp: minus_unique)
  } note ainv = this
  thus inv: "a_inv ?r = uminus" by (intro ext)
  {
    fix x y :: 'a
    have "x ?r y = x - y"
      apply (subst a_minus_def)
      using inv by auto
  }
  thus "(λx y. x ?r y) = minus" by (intro ext)
qed (auto simp: class_ring_simps)

(* Interpretation of the locale `cring` for class_ring at any type of class
   comm_ring_1.  Repeats the earlier rewrite clauses and adds one new rule,
   stored in class_ring_simps: finprod becomes prod.
   Proof obligations:
     - cring:   unfold_locales on top of the ring interpretation;
     - a_inv, a_minus, pow, finsum: from class_ring_simps;
     - finprod: infinite_finite_induct on the index set. *)
interpretation class_cring: cring "class_ring :: 'a :: comm_ring_1 ring"
  rewrites "carrier class_ring = UNIV"
    and "mult class_ring = (*)"
    and "add class_ring = (+)"
    and "one class_ring = 1"
    and "zero class_ring = 0"
    and "a_inv (class_ring :: 'a ring) = uminus"
    and "a_minus (class_ring :: 'a ring) = minus"
    and "pow (class_ring :: 'a ring) = (^)"
    and "finsum (class_ring :: 'a ring) = sum"
    and [class_ring_simps]: "finprod class_ring = prod"
proof -
  let ?r = "class_ring :: 'a ring"
  interpret ring ?r ..
  show "cring ?r"
    by (unfold_locales, auto)
  then interpret cring ?r .
  show "a_inv (class_ring :: 'a ring) = uminus"
    and "a_minus (class_ring :: 'a ring) = minus"
    and "pow (class_ring :: 'a ring) = (^)"
    and "finsum (class_ring :: 'a ring) = sum" by (simp_all add: class_ring_simps)
  {
    fix f and A :: "'b set"
    have "finprod ?r f A = prod f A"
      by (induct A rule: infinite_finite_induct, auto)
  }
  thus "finprod ?r = prod" by (intro ext)
qed (auto simp: class_ring_simps)

(* div0: a name for the value of `m_inv class_ring` at 0.  It is used in the
   rewrite rule for m_inv in the class_field interpretation, which returns
   div0 in the case x = 0. *)
definition div0 :: "'a :: {one,plus,times,zero}" where
  "div0  m_inv (class_ring :: 'a ring) 0"

(* class_ring is a `field` in the sense of the algebra locale whenever the
   element type is of class field.
   Proof: on top of the cring interpretation, a local simp fact provides an
   inverse for an element x (witness `inverse x`) under the stated premise on
   x; unfold_locales with Units_def finishes.
   NOTE(review): the premise `x  0` has lost its relation symbol in this
   copy; presumably it states that x is non-zero -- confirm. *)
lemma class_field: "field (class_ring :: 'a :: field ring)" (is "field ?r")
proof -
  interpret cring ?r ..
  {
    fix x :: 'a
    have "x  0  xa. xa * x = 1  x * xa = 1"
      by (intro exI[of _ "inverse x"], auto)
  } note [simp] = this
  show "field ?r" 
    by (unfold_locales, auto simp: Units_def)
qed

(* Interpretation of the locale `field` for class_ring at any type of class
   field.  Repeats all earlier rewrite clauses and adds one new rule, stored
   in class_ring_simps, for the multiplicative inverse: m_inv at x is div0
   when x = 0 and `inverse x` otherwise.
   Proof obligations:
     - field:  lemma class_field;
     - a_inv, a_minus, pow, finsum, finprod: from class_ring_simps;
     - m_inv:  case split on x = 0; the zero case is the definition of div0,
               the other case uses uniqueness of the described element
               (the1_equality, inverse_unique). *)
interpretation class_field: field "class_ring :: 'a :: field ring"
  rewrites "carrier class_ring = UNIV"
    and "mult class_ring = (*)"
    and "add class_ring = (+)"
    and "one class_ring = 1"
    and "zero class_ring = 0"
    and "a_inv class_ring = uminus"
    and "a_minus class_ring = minus"
    and "pow class_ring = (^)"
    and "finsum class_ring = sum"
    and "finprod class_ring = prod"
    and [class_ring_simps]: "m_inv (class_ring :: 'a ring) x = 
      (if x = 0 then div0 else inverse x)" 
    (* problem that m_inv ?r 0 = inverse 0 is not guaranteed  *)
proof -
  let ?r = "class_ring :: 'a ring"
  show "field ?r" using class_field.
  then interpret field ?r.
  show "a_inv ?r = uminus"
    and "a_minus ?r = minus"
    and "pow ?r = (^)"
    and "finsum ?r = sum"
    and "finprod ?r = prod" by (fact class_ring_simps)+
  show "inv?r x = (if x = 0 then div0 else inverse x)"
  proof (cases "x = 0")
    case True
    thus ?thesis unfolding div0_def by simp
  next
    case False
    thus ?thesis unfolding m_inv_def
      by (intro the1_equality ex1I[of _ "inverse x"], auto simp: inverse_unique)
  qed
qed (auto simp: class_ring_simps)

(* Combined rule set: module_mat_simps together with class_ring_simps. *)
lemmas matrix_vs_simps = module_mat_simps class_ring_simps

(* class_field: a constant defined to be class_ring at types of class field.
   Its defining equation is stored in class_ring_simps, so rewriting with
   that collection replaces class_field by class_ring. *)
definition class_field :: "'a :: field ring"
  where [class_ring_simps]: "class_field  class_ring"




(* Locale packaging the ring of n-by-n matrices over a field type.
   Parameters:
     n          - the matrix dimension;
     field_type - a type token fixing the element type 'a (class field).
   Inside, R abbreviates `ring_mat TYPE('a) n n`, and the sublocale makes all
   facts of the locale `ring` available for R, with carrier, add, mult, one
   and zero rewritten to carrier_mat n n, the matrix addition and
   multiplication, the identity matrix `1m n` and the zero matrix `0m n n`. *)
locale matrix_ring = 
  fixes n :: nat
    and field_type :: "'a :: field itself"
begin
abbreviation R where "R  ring_mat TYPE('a) n n"
sublocale ring R
  rewrites "carrier R = carrier_mat n n"
    and "add R = (+)"
    and "mult R = (*)"
    and "one R = 1m n"
    and "zero R = 0m n n"
  using ring_mat by (auto simp: ring_mat_simps)

end

(* The structure `module_mat TYPE('a) nr nc` is a vector space over
   class_ring for a field type 'a.
   Proof: it is an abelian group by abelian_group_mat; the remaining locale
   obligations are unfolded with matrix_vs_simps and closed using the two
   distributivity laws add_smult_distrib_left_mat and
   add_smult_distrib_right_mat. *)
lemma matrix_vs: "vectorspace (class_ring :: 'a :: field ring) (module_mat TYPE('a) nr nc)"
proof -
  interpret abelian_group "module_mat TYPE('a) nr nc"
    by (rule abelian_group_mat)
  show ?thesis unfolding class_field_def
    by (unfold_locales, unfold matrix_vs_simps, 
      auto simp: add_smult_distrib_left_mat add_smult_distrib_right_mat)
qed


(* Locale packaging vectors of length n over a commutative ring type as a
   module.
   Parameters:
     f_ty - a type token fixing the element type 'a (class comm_ring_1);
     n    - the vector length.
   Inside, V abbreviates `module_vec TYPE('a) n`.  The sublocale makes the
   facts of the locale Module.module available for class_ring and V.  Its
   rewrite clauses translate the record components of V and of class_ring
   into the concrete vector and type-class operations; the last five clauses
   simplify side conditions that mention UNIV or True. *)
locale vec_module =
  fixes f_ty::"'a::comm_ring_1 itself"
  and n::"nat"
begin

abbreviation V where "V  module_vec TYPE('a) n"

sublocale Module.module "class_ring :: 'a ring" V
  rewrites "carrier V = carrier_vec n"
    and "add V = (+)"
    and "zero V = 0v n"
    and "module.smult V = (⋅v)"
    and "carrier class_ring = UNIV"
    and "monoid.mult class_ring = (*)"
    and "add class_ring = (+)"
    and "one class_ring = 1"
    and "zero class_ring = 0"
    and "a_inv (class_ring :: 'a ring) = uminus"
    and "a_minus (class_ring :: 'a ring) = (-)"
    and "pow (class_ring :: 'a ring) = (^)"
    and "finsum (class_ring :: 'a ring) = sum"
    and "finprod (class_ring :: 'a ring) = prod"
    and "X. X  UNIV = True" (* These rewrite rules will clean lemmas *)
    and "x. x  UNIV = True"
    and "a A. a  A  UNIV  True"
    and "P. P  True  P"
    and "P. (True  P)  Trueprop P"
  apply unfold_locales
  apply (auto simp: module_vec_simps class_ring_simps Units_def add_smult_distrib_vec 
      smult_add_distrib_vec intro!:bexI[of _ "- _"])
  done

end

(* Locale packaging nr-by-nc matrices over a field type as a vector space.
   Parameters:
     nr, nc     - the matrix dimensions;
     field_type - a type token fixing the element type 'a (class field).
   Inside, V abbreviates `module_mat TYPE('a) nr nc`.  The sublocale makes
   the facts of the locale `vectorspace` available for class_ring and V; its
   rewrite clauses translate the record components of V and of class_ring
   into the concrete matrix and type-class operations, including the
   div0-based rule for m_inv.  Established from lemma matrix_vs. *)
locale matrix_vs = 
  fixes nr :: nat
    and nc :: nat
    and field_type :: "'a :: field itself"
begin

abbreviation V where "V  module_mat TYPE('a) nr nc"
sublocale  
  vectorspace class_ring V
  rewrites "carrier V = carrier_mat nr nc"
    and "add V = (+)"
    and "mult V = (*)"
    and "one V = 1m nr"
    and "zero V = 0m nr nc"
    and "smult V = (⋅m)"
    and "carrier class_ring = UNIV"
    and "mult class_ring = (*)"
    and "add class_ring = (+)"
    and "one class_ring = 1"
    and "zero class_ring = 0"
    and "a_inv (class_ring :: 'a ring) = uminus"
    and "a_minus (class_ring :: 'a ring) = minus"
    and "pow (class_ring :: 'a ring) = (^)"
    and "finsum (class_ring :: 'a ring) = sum"
    and "finprod (class_ring :: 'a ring) = prod"
    and "m_inv (class_ring :: 'a ring) x = 
      (if x = 0 then div0 else inverse x)"
  by (rule matrix_vs, auto simp: matrix_vs_simps class_field_def)
end

(* The structure `module_vec TYPE('a) n` is a module over class_ring for a
   field type 'a.
   Proof: first interpret it as an abelian group (additive inverses come from
   add_inv_exists_vec), then discharge the module axioms after unfolding
   class_ring_simps and module_vec_simps, using the distributivity laws
   add_smult_distrib_vec and smult_add_distrib_vec. *)
lemma vec_module: "module (class_ring :: 'a :: field ring) (module_vec TYPE('a) n)"
proof -
  interpret abelian_group "module_vec TYPE('a) n"
    apply (unfold_locales)
    unfolding module_vec_def Units_def
    using add_inv_exists_vec by auto
  show ?thesis
    unfolding class_field_def
    apply (unfold_locales)
    unfolding class_ring_simps
    unfolding module_vec_simps
    using add_smult_distrib_vec
    by (auto simp: smult_add_distrib_vec)
qed

(* The structure `module_vec TYPE('a) n` is a vector space over class_ring
   for a field type 'a: after unfolding vectorspace_def this follows from
   lemma vec_module and lemma class_field. *)
lemma vec_vs: "vectorspace (class_ring :: 'a :: field ring) (module_vec TYPE('a) n)"
  unfolding vectorspace_def
  using vec_module class_field 
  by (auto simp: class_field_def)

locale vec_space =
  fixes f_ty::"'a::field itself"
  and n::"nat"
begin

  (* Every vec_space is a vec_module for the same type token and length, so
     the abbreviation V and the module facts are inherited. *)
  sublocale vec_module f_ty n.

  (* Makes the facts of the locale `vectorspace` available for class_ring and
     V.  The first four rewrite clauses (carrier, add, zero, smult of V) are
     additionally declared [simp]; the carrier rule is named cV.  The other
     clauses translate the components of class_ring into type-class
     operations, including the div0-based rule for m_inv.
     Established from vec_vs with module_vec_simps and class_ring_simps. *)
  sublocale vectorspace class_ring V 
  rewrites cV[simp]: "carrier V = carrier_vec n"
    and [simp]: "add V = (+)"
    and [simp]: "zero V = 0v n"
    and [simp]: "smult V = (⋅v)"
    and "carrier class_ring = UNIV"
    and "mult class_ring = (*)"
    and "add class_ring = (+)"
    and "one class_ring = 1"
    and "zero class_ring = 0"
    and "a_inv (class_ring :: 'a ring) = uminus"
    and "a_minus (class_ring :: 'a ring) = minus"
    and "pow (class_ring :: 'a ring) = (^)"
    and "finsum (class_ring :: 'a ring) = sum"
    and "finprod (class_ring :: 'a ring) = prod"
    and "m_inv (class_ring :: 'a ring) x = (if x = 0 then div0 else inverse x)"
  using vec_vs
  unfolding class_field_def
  by (auto simp: module_vec_simps class_ring_simps)

(* The constant `finsum_vec TYPE('a) n` coincides with the locale-level
   `finsum V`.  Proved by unfolding finsum_vec_def, monoid_vec_def,
   finsum_def and finprod_def.  Declared [simp]. *)
lemma finsum_vec[simp]: "finsum_vec TYPE('a) n = finsum V"
  by (force simp: finsum_vec_def monoid_vec_def finsum_def finprod_def)

(* A binary operation with a fixed vector w distributes over a finite vector
   sum: applying it to `finsum V f U` and w equals the `sum` over U of the
   operation applied to each f u and w.
   NOTE(review): the infix operator between the sum and w is not visible in
   this copy; the lemma name and the use of add_scalar_prod_distrib suggest
   the scalar product -- confirm.
   Proof: infinite_finite_induct on U; the insert case unrolls the vector sum
   with finsum_insert and applies add_scalar_prod_distrib. *)
lemma finsum_scalar_prod_sum:
  assumes f: "f : U  carrier_vec n"
      and w: "w: carrier_vec n"
  shows "finsum V f U  w = sum (λu. f u  w) U"
  using w f
proof (induct U rule: infinite_finite_induct)
  case (insert u U)
    hence f: "f : U  carrier_vec n" "f u : carrier_vec n"  by auto
    show ?case
      unfolding finsum_insert[OF insert(1) insert(2) f]
      apply (subst add_scalar_prod_distrib) using insert by auto
qed auto

(* For x in carrier_vec n, the left-hand side `V x` equals the vector
   negation `- x`.  Declared [simp].
   NOTE(review): the prefix operator on the left-hand side is not visible in
   this copy; since the proof unfolds a_inv_def and m_inv_def it is
   presumably the additive inverse of the structure V -- confirm.
   Proof: the unfolded definite description is resolved with the_equality,
   using M.minus_unique and the vector facts uminus_carrier_vec and
   uminus_r_inv_vec. *)
lemma vec_neg[simp]: assumes "x : carrier_vec n" shows "V x = - x"
  unfolding a_inv_def m_inv_def apply simp 
  apply (rule the_equality, intro conjI)
  using assms apply auto
  using M.minus_unique uminus_carrier_vec uminus_r_inv_vec by blast

(* A finite sum of vectors from carrier_vec n has dimension n (for a finite
   index set A and f sending A into carrier_vec n).
   Proof: induction on the finite set; the insert case unrolls the sum with
   finsum_insert. *)
lemma finsum_dim:
  "finite A  f  A  carrier_vec n  dim_vec (finsum V f A) = n"
proof(induct set:finite)
  case (insert a A)
    hence dfa: "dim_vec (f a) = n" by auto
    have f: "f  A  carrier_vec n" using insert by auto
    hence fa: "f a  carrier_vec n" using insert by auto
    show ?case
      unfolding finsum_insert[OF insert(1) insert(2) f fa]
      using insert by auto
qed simp

(* A linear combination over a finite set X of vectors from carrier_vec n has
   dimension n.  Instance of finsum_dim for the summand function ?f, which
   pairs each vector v with its coefficient a v. *)
lemma lincomb_dim:
  assumes fin: "finite X"
    and X: "X  carrier_vec n"
  shows "dim_vec (lincomb a X) = n"
proof -
  let ?f = "λv. a v v v"
  have f: "?f  X  carrier_vec n" apply rule using X by auto
  show ?thesis
    unfolding lincomb_def
    using finsum_dim[OF fin f].
qed


(* Component-wise reading of a vector sum: for i < n, the i-th entry of
   `finsum V f X` is the `sum` over X of the i-th entries of the f x.
   Proof: infinite_finite_induct on X.  In the insert case the vector sum and
   the scalar sum are both unrolled (finsum_insert, sum.insert) and the entry
   of a vector addition is computed with index_add_vec; i2 supplies the index
   bound that index_add_vec requires. *)
lemma finsum_index:
  assumes i: "i < n"
    and f: "f  X  carrier_vec n"
    and X: "X  carrier_vec n"
  shows "finsum V f X $ i = sum (λx. f x $ i) X"
  using X f
proof (induct X rule: infinite_finite_induct)
  case empty
    then show ?case using i by simp next
  case (insert x X)
    then have Xf: "finite X"
      and xX: "x  X"
      and x: "x  carrier_vec n"
      and X: "X  carrier_vec n"
      and fx: "f x  carrier_vec n"
      and f: "f  X  carrier_vec n" by auto
    have i2: "i < dim_vec (finsum V f X)"
      using i finsum_closed[OF f] by auto
    have ix: "i < dim_vec x" using x i by auto
    show ?case
      unfolding finsum_insert[OF Xf xX f fx]
      unfolding sum.insert[OF Xf xX]
      unfolding index_add_vec(1)[OF i2]
      using insert lincomb_def
      by auto
qed (insert i, auto)

(* Component-wise reading of a linear combination: for i < n, the i-th entry
   of `lincomb a X` is the `sum` over X of the coefficient a x times the i-th
   entry of x.
   Proof: unfold lincomb_def, apply finsum_index, and use the pointwise fact
   `point` about the i-th entry of a scaled vector. *)
lemma lincomb_index:
  assumes i: "i < n"
    and X: "X  carrier_vec n"
  shows "lincomb a X $ i = sum (λx. a x * x $ i) X"
proof -
  let ?f = "λx. a x v x"
  have f: "?f : X  carrier_vec n" using X by auto
  have point: "v. v  X  (a v v v) $ i = a v * v $ i" using i X by auto
  show ?thesis
    unfolding lincomb_def
    unfolding finsum_index[OF i f X]
    using point X by simp
qed

(* The element set of a list extended by one element at the end is obtained
   by inserting that element.  Used as an unfolding rule in lincomb_units. *)
lemma append_insert: "set (xs @ [x]) = insert x (set xs)" by simp

(* For i < n, the i-th entry of a linear combination of the unit vectors
   `unit_vecs n` is the coefficient of the i-th unit vector.
   Proof: after lincomb_index and unit_vecs_first, the goal is a statement
   about ?f m i, the i-th entry sum over the list `unit_vecs_first n m`.
     - `zero`: ?f m j = 0 under the stated side conditions on m and j
       (NOTE(review): the relation between m and j is not visible in this
       copy -- confirm), by induction on m.
     - second block: for i < m (and a bound on m whose relation symbol is
       likewise not visible), ?f m i is the coefficient of unit_vec n i, by
       induction on m with a case split on i = m.
   Both inductions unroll unit_vecs_first by one element (append_insert,
   sum.insert) and evaluate the entry with index_unit_vec; `mem` is the side
   condition that sum.insert needs. *)
lemma lincomb_units:
  assumes i: "i < n" 
  shows "lincomb a (set (unit_vecs n)) $ i = a (unit_vec n i)"
  unfolding lincomb_index[OF i unit_vecs_carrier]
  unfolding unit_vecs_first
proof -
  let ?f = "λm i. xset (unit_vecs_first n m). a x * x $ i"
  have zero:"m j. m  j  j < n  ?f m j = 0"
  proof -
    fix m
    show "j. m  j  j < n  ?f m j = 0"
    proof (induction m)
      case (Suc m)
        hence mj:"mj" and mj':"mj" and jn:"j<n" and mn:"m<n" by auto
        hence mem: "unit_vec n m  set (unit_vecs_first n m)"
          apply(subst unit_vecs_first_distinct) by auto
        show ?case
          unfolding unit_vecs_first.simps
          unfolding append_insert
          unfolding sum.insert[OF finite_set mem]
          unfolding index_unit_vec(1)[OF mn jn]
          unfolding Suc(1)[OF mj jn] using mj' by simp
    qed simp
  qed
  { fix m
    have "i < m  m  n  ?f m i = a (unit_vec n i)"
    proof (induction m arbitrary: i)
      case (Suc m)
        hence iSm: "i < Suc m" and i:"i<n" and mn: "m<n" by auto
        hence mem: "unit_vec n m  set (unit_vecs_first n m)"
          apply(subst unit_vecs_first_distinct) by auto
        show ?case
          unfolding unit_vecs_first.simps
          unfolding append_insert
          unfolding sum.insert[OF finite_set mem]
          unfolding index_unit_vec(1)[OF mn i]
          using zero Suc by (cases "i = m",auto)
    qed auto
  }
  thus "?f n i = a (unit_vec n i)" using assms by auto
qed

(* Coordinate representation: a vector v of carrier_vec n equals the linear
   combination of the unit vectors whose coefficient for a vector u is the
   entry "v $ i" at the index i determined (via THE) by "u = unit_vec n i".
   The proof compares entries, using lincomb_dim for the dimension and
   lincomb_units together with fvu for the individual entries. *)
lemma lincomb_coordinates:
  assumes v: "v : carrier_vec n"
  defines "a  (λu. v $ (THE i. u = unit_vec n i))"
  shows "lincomb a (set (unit_vecs n)) = v"
proof -
  have a: "a  set (unit_vecs n)  UNIV" by auto
  (* On a unit vector, the coefficient function returns the matching entry of v. *)
  have fvu: "i. i < n  v $ i = a (unit_vec n i)"
    unfolding a_def using unit_vec_eq by auto
  show ?thesis
    apply rule
    unfolding lincomb_dim[OF finite_set unit_vecs_carrier]
    using v lincomb_units fvu
    by auto
qed

(* The span of the unit vectors is the whole carrier "carrier_vec n".
   Proved as two inclusions (?L is the span, ?R the carrier):
     - an element of the span is a lincomb over the unit vectors (finite_span),
       hence lies in the carrier by lincomb_closed;
     - an element of the carrier is such a lincomb by lincomb_coordinates. *)
lemma span_unit_vecs_is_carrier: "span (set (unit_vecs n)) = carrier_vec n" (is "?L = ?R")
proof (rule;rule)
  fix v assume vsU: "v  ?L" show "v  ?R"
  proof -
    obtain a
      where v: "v = lincomb a (set (unit_vecs n))"
      using vsU
      unfolding finite_span[OF finite_set unit_vecs_carrier] by auto
    thus ?thesis using lincomb_closed[OF unit_vecs_carrier] by auto
  qed
  next fix v::"'a vec" assume v: "v  ?R" show "v  ?L"
    unfolding span_def
    using lincomb_coordinates[OF v,symmetric] by auto
qed

(* The predicate fin_dim holds in this context (declared [simp]).
   After unfolding fin_dim_def, the remaining goals are closed by auto using
   span_unit_vecs_is_carrier and unit_vecs_carrier. *)
lemma fin_dim[simp]: "fin_dim"
  unfolding fin_dim_def
  apply (intro eqTrueI exI conjI) using span_unit_vecs_is_carrier unit_vecs_carrier by auto

(* The set of unit vectors is a basis.
   After unfolding basis_def and rewriting the span with
   span_unit_vecs_is_carrier, the only conjunct proved by hand is linear
   independence; the others are discharged by the closing
   "qed (insert unit_vecs_carrier, auto)".
   Independence is shown by contradiction: a dependency (A, a, v) is extended
   to a coefficient function b on all unit vectors (b is 0 where the if-test
   in its definition fails), the i-th entry of "lincomb b" over all unit
   vectors is computed to be 0, and lincomb_units then forces "b v = 0",
   which contradicts the fact av. *)
lemma unit_vecs_basis: "basis (set (unit_vecs n))" unfolding basis_def span_unit_vecs_is_carrier
proof (intro conjI)
  show "¬ lin_dep (set (unit_vecs n))" 
  proof
    assume "lin_dep (set (unit_vecs n))"
    from this[unfolded lin_dep_def] obtain A a v where
      fin: "finite A" and A: "A  set (unit_vecs n)"  
      and lc: "lincomb a A = 0v n" and v: "v  A" and av: "a v  0"      
      by auto
    (* v is one of the unit vectors; name its index i. *)
    from v A obtain i where i: "i < n" and vu: "v = unit_vec n i" unfolding unit_vecs_def by auto
    define b where "b = (λ x. if x  A then a x else 0)"
    have id: "A  (set (unit_vecs n) - A) = set (unit_vecs n)" using A by auto
    (* Split the sum over all unit vectors into the part over A and the rest. *)
    from lincomb_index[OF i unit_vecs_carrier]
    have "lincomb b (set (unit_vecs n)) $ i = (x (A  (set (unit_vecs n) - A)). b x * x $ i)" 
      unfolding id .
    also have " = (x A. b x * x $ i) + (x set (unit_vecs n) - A. b x * x $ i)"
      by (rule sum.union_disjoint, insert fin, auto)
    also have "(x A. b x * x $ i) = (x A. a x * x $ i)"
      by (rule sum.cong, auto simp: b_def)
    also have " = lincomb a A $ i" 
      by (subst lincomb_index[OF i], insert A unit_vecs_carrier, auto)
    also have " = 0" unfolding lc using i by simp
    also have "(x set (unit_vecs n) - A. b x * x $ i) = 0"
      by (rule sum.neutral, auto simp: b_def)
    finally have "lincomb b (set (unit_vecs n)) $ i = 0" by simp
    from lincomb_units[OF i, of b, unfolded this]
    have "b v = 0" unfolding vu by simp
    with v av show False unfolding b_def by simp
  qed
qed (insert unit_vecs_carrier, auto)

(* The list "unit_vecs n" has exactly n elements (declared [simp]);
   follows by unfolding unit_vecs_def. *)
lemma unit_vecs_length[simp]: "length (unit_vecs n) = n"
  unfolding unit_vecs_def by auto

(* The list "unit_vecs n" has no repeated elements.
   Via distinct_conv_nth it suffices to show that entries at two positions
   i, j (facts collected in * ) differ; if they were equal, applying
   "λ v. v $ i" to both sides (arg_cong) yields a contradiction. *)
lemma unit_vecs_distinct: "distinct (unit_vecs n)"
  unfolding distinct_conv_nth unit_vecs_length
proof (intro allI impI)
  fix i j
  assume *: "i < n" "j < n" "i  j"
  show "unit_vecs n ! i  unit_vecs n ! j"
  proof
    assume "unit_vecs n ! i = unit_vecs n ! j"
    from arg_cong[OF this, of "λ v. v $ i"] 
    show False using * unfolding unit_vecs_def by auto
  qed
qed

(* The dimension "dim" of this context equals n.
   dim_basis rewrites dim to the cardinality of the unit-vector basis,
   distinct_card rewrites that cardinality to the list length, and simp
   finishes (unit_vecs_length is a simp rule). *)
lemma dim_is_n: "dim = n"
  unfolding dim_basis[OF finite_set unit_vecs_basis]
  unfolding distinct_card[OF unit_vecs_distinct]
  by simp

end

(* Locale mat_space: extends vec_space (instantiated with field type token
   f_ty and the natural number nc) by one additional fixed natural number nr.
   Inside the locale, M abbreviates "ring_mat TYPE('a) nc nr". *)
locale mat_space =
  vec_space f_ty nc for f_ty::"'a::field itself" and nc::"nat" +
  fixes nr :: "nat"
begin
  abbreviation M where "M  ring_mat TYPE('a) nc nr"
end

context vec_space
begin
(* For a finite set A of vectors of V, the vector space built on "span A"
   (written "vs (span A)") is finite dimensional in the sense of
   vectorspace.fin_dim.
   Steps: span_vs A is a vector space; A is contained in its own span;
   span A is a submodule of V; span_li_not_depend(1) then identifies the span
   of A computed inside "vs (span A)" with the carrier of that space, and
   unfolding vectorspace.fin_dim_def finishes the proof. *)
lemma fin_dim_span:
assumes "finite A" "A  carrier V"
shows "vectorspace.fin_dim class_ring (vs (span A))"
proof -
  have "vectorspace class_ring (span_vs A)"
   using assms span_is_subspace subspace_def subspace_is_vs by simp
  have "A  span A" using assms in_own_span by simp
  have "submodule class_ring (span A) V" using assms span_is_submodule by simp
  have "LinearCombinations.module.span class_ring (vs (span A)) A = carrier (vs (span A))"
    using  span_li_not_depend(1)[OF A  span A ‹submodule class_ring (span A) V›] by auto
  then show ?thesis unfolding vectorspace.fin_dim_def[OF ‹vectorspace class_ring (span_vs A)]
    using List.finite_set A  span A ‹vectorspace class_ring (vs (span A))
    vec_vs vectorspace.carrier_vs_is_self[OF ‹vectorspace class_ring (span_vs A)] using assms(1) by auto
qed

(* Specialisation of fin_dim_span to the columns of a matrix A from
   "carrier_mat n nc": the space built on the span of "set (cols A)" is
   finite dimensional.  Closed by force from fin_dim_span and the listed
   dimension facts. *)
lemma fin_dim_span_cols:
assumes "A  carrier_mat n nc"
shows "vectorspace.fin_dim class_ring (vs (span (set (cols A))))"
  using fin_dim_span cols_dim List.finite_set assms carrier_matD(1) module_vec_simps(3) by force
end

context vec_module
begin

(* A list-based linear combination is a matrix-vector product:
   "lincomb_list c ws" equals "mat_of_cols n ws" applied to the coefficient
   vector "vec (length ws) c", given the dimension assumption on the
   elements of ws.  ?l and ?r abbreviate the two sides as functions of ws, c.
   Induction on ws with c arbitrary; in the Cons case the equality is shown
   entrywise for each index i < n and then lifted with eq_vecI. *)
lemma lincomb_list_as_mat_mult:
  assumes "w  set ws. dim_vec w = n"
  shows "lincomb_list c ws = mat_of_cols n ws *v vec (length ws) c" (is "?l ws c = ?r ws c")
proof (insert assms, induct ws arbitrary: c)
  case Nil
  then show ?case by (auto simp: mult_mat_vec_def scalar_prod_def)
next
  case (Cons w ws)
  { fix i assume i: "i < n"
    (* Peel off the head; the tail uses the induction hypothesis with shifted coefficients. *)
    have "?l (w#ws) c = c 0 v w + mat_of_cols n ws *v vec (length ws) (c  Suc)"
      by (simp add: Cons o_def)
    also have " $ i = ?r (w#ws) c $ i"
      using Cons i index_smult_vec
      by (simp add: mat_of_cols_Cons_index_0 mat_of_cols_Cons_index_Suc o_def vec_Suc mult_mat_vec_def row_def length_Cons)
    finally have "?l (w#ws) c $ i = ".
  }
  with Cons show ?case by (intro eq_vecI, auto)
qed

(* Splitting a linear combination along a part B of a finite set A:
   "lincomb f A = lincomb f (A-B) + lincomb f B".
   A is first rewritten as the union of A - B and B (using BA), then
   lincomb_union is applied. *)
lemma lincomb_vec_diff_add:
    assumes A: "A  carrier_vec n"
    and BA: "B  A" and fin_A: "finite A" 
    and f: "f  A  UNIV" shows "lincomb f A = lincomb f (A-B) + lincomb f B"
proof -
  have "A - B  B = A" using BA by auto
  hence "lincomb f A = lincomb f (A - B  B)"  by simp
  also have "... = lincomb f (A-B) + lincomb f B"
    by (rule lincomb_union, insert assms, auto intro: finite_subset)
  finally show ?thesis .
qed

(* The vector "M.sumlist xs" has dimension n, given the dimension assumption
   on the elements of xs.  By induction on xs. *)
lemma dim_sumlist:
  assumes "xset xs. dim_vec x = n"
  shows "dim_vec (M.sumlist xs) = n" using assms by (induct xs, auto)

(* Entrywise description of M.sumlist: for i<n, the i-th entry of
   "M.sumlist xs" is the sum of "(xs ! j) $ i" over positions j in
   {0..<length xs}.
   Induction from the right (rev_induct): the appended element a is split
   off with M.sumlist_append, the entry of the sum is computed with
   index_add_vec(1), and the index range is extended by one position. *)
lemma sumlist_nth:
  assumes "xset xs. dim_vec x = n" and "i<n"
  shows "(M.sumlist xs) $ i= sum (λj. (xs ! j) $ i) {0..<length xs}"
  using assms
proof (induct xs rule: rev_induct)
  case (snoc a xs) 
  (* Local simp rules: carrier membership of the list elements and of a. *)
  have [simp]: "x  carrier_vec n" if x: "xset xs" for x 
    using snoc.prems x unfolding carrier_vec_def by auto
  have [simp]: "a  carrier_vec n" 
    using snoc.prems unfolding carrier_vec_def by auto
  have hyp: "M.sumlist xs $ i = (j = 0..<length xs. xs ! j $ i)" 
    by (rule snoc.hyps, auto simp add: snoc.prems)  
  have "M.sumlist (xs @ [a]) = M.sumlist xs + M.sumlist [a]" 
    by (rule M.sumlist_append, auto simp add: snoc.prems)
  also have "... = M.sumlist xs + a" by auto
  also have "... $ i = (M.sumlist xs $ i) + (a $ i)" 
    by (rule index_add_vec(1), auto simp add: snoc.prems)
  also have "... =  (j = 0..<length xs. xs ! j $ i) + (a $ i)" unfolding hyp by simp
  also have "... = (j = 0..<length (xs @ [a]). (xs @ [a]) ! j $ i)"
    by (auto, rule sum.cong, auto simp add: nth_append)     
  finally show ?case .
qed auto

(* For a distinct list ws of vectors, the set-based linear combination over
   "set ws" with coefficients f equals the list-based one whose coefficient
   at position i is "f (ws ! i)".
   Induction on ws.  In the Cons case the finite sum over "insert a (set ws)"
   is split with finsum_insert, the tail is replaced using the induction
   hypothesis, and the result is reassembled as a sumlist over "a # ws" and
   then re-indexed by positions. *)
lemma lincomb_as_lincomb_list_distinct:
  assumes s: "set ws  carrier_vec n" and d: "distinct ws"
  shows "lincomb f (set ws) = lincomb_list (λi. f (ws ! i)) ws"
proof (insert assms, induct ws)
  case Nil
  then show ?case by auto
next
  case (Cons a ws)
  have [simp]: "v. v  set ws  v  carrier_vec n" using Cons.prems(1) by auto
  then have ws: "set ws  carrier_vec n" by auto
  have hyp: "lincomb f (set (ws)) = lincomb_list (λi. f (ws ! i)) ws"
  proof (intro Cons.hyps ws)
    show "distinct ws" using Cons.prems(2) by auto
  qed  
  (* Mapping over positions and mapping over elements give the same list. *)
  have "(map (λi. f (ws ! i) v ws ! i) [0..<length ws]) = (map (λv. f v v v) ws)"
    by (intro nth_equalityI, auto)
  with ws have sumlist_rw: "sumlist (map (λi. f (ws ! i) v ws ! i) [0..<length ws])
    = sumlist (map (λv. f v v v) ws)"
    by (subst (1 2) sumlist_as_summset, auto)
  have "lincomb f (set (a # ws)) = (Vvset (a # ws). f v v v)" unfolding lincomb_def ..
  also have "... = (Vv insert a (set ws). f v v v)" by simp
  also have "... = (f a v a) + (Vv (set ws). f v v v)"
    by (rule finsum_insert, insert Cons.prems, auto)
  also have "... = f a v a + lincomb_list (λi. f (ws ! i)) ws" using hyp lincomb_def by auto
  also have "... = f a v a + sumlist (map (λv. f v v v) ws)" 
    unfolding lincomb_list_def sumlist_rw by auto
  also have "... = sumlist (map (λv. f v v v) (a # ws))"
  proof -
    (* Write the head summand as a one-element sumlist and join with sumlist_append. *)
    let ?a = "(map (λv. f v v v) [a])"
    have a: "a  carrier_vec n" using Cons.prems(1) by auto
    have "f a v a = sumlist (map (λv. f v v v) [a])" using Cons.prems(1) by auto
    hence "f a v a + sumlist (map (λv. f v v v) ws) 
      = sumlist ?a + sumlist (map (λv. f v v v) ws)" by simp
    also have "... = sumlist (?a @ (map (λv. f v v v) ws))"
      by (rule sumlist_append[symmetric], auto simp add: a)
    finally show ?thesis by auto
  qed
  also have "... = sumlist (map (λi. f ((a # ws) ! i) v (a # ws) ! i) [0..<length (a # ws)])"
  proof -
    have u: "(map (λi. f ((a # ws) ! i) v (a # ws) ! i) [0..<length (a # ws)]) 
        = (map (λv. f v v v) (a # ws))"
    proof (intro nth_equalityI, goal_cases)
      case (2 i) thus ?case by (smt length_map map_nth nth_map)
    qed auto
    show ?thesis unfolding u ..
  qed
  also have "... = lincomb_list (λi. f ((a # ws) ! i)) (a # ws)"
    unfolding lincomb_list_def ..
  finally show ?case .
qed

end

locale idom_vec = vec_module f_ty for f_ty :: "'a :: idom itself"
begin

(* Auxiliary form of "dependent columns imply zero determinant", stated for a
   matrix given as A = mat_of_cols n ws: if A is in "carrier_mat n n" and
   "set (cols A)" is linearly dependent, then det A = 0.
   Case split on "distinct ws":
     - not distinct: two columns of A coincide and det_identical_columns applies;
     - distinct: the dependency (A', f', v) is turned into a coefficient
       function f on all columns with "lincomb f (set (cols A)) = 0v n"; the
       vector "vec (length ws) ?c" of these coefficients is shown to be a
       witness for the existential statement below, and
       det_0_iff_vec_prod_zero[OF A] converts that statement into det A = 0. *)
lemma lin_dep_cols_imp_det_0':
  fixes ws
  defines "A  mat_of_cols n ws"
  assumes dimv_ws: "wset ws. dim_vec w = n"
  assumes A: "A  carrier_mat n n" and ld_cols: "lin_dep (set (cols A))"
  shows  "det A = 0"
proof (cases "distinct ws")
  case False
  obtain i j where ij: "ij" and c: "col A i = col A j" and i: "i<n" and j: "j<n" 
    using False A unfolding A_def
    by (metis dimv_ws distinct_conv_nth carrier_matD(2) 
        col_mat_of_cols mat_of_cols_carrier(3) nth_mem carrier_vecI)
  show ?thesis by (rule det_identical_columns[OF A ij i j c])  
next
  case True
  have d1[simp]: "x. x  set ws  x  carrier_vec n" using dimv_ws by auto 
  (* Unpack the linear dependency of the columns. *)
  obtain A' f' v where f'_in: "f'  A'  UNIV" 
    and lc_f': "lincomb f' A' = 0v n" and f'_v: "f' v  0"
    and v_A': "v  A'" and A'_in_rows: "A'  set (cols A)" 
    using ld_cols unfolding lin_dep_def by auto
  (* f is defined from f' by a case distinction on A' with default value 0. *)
  define f where "f  λx. if x  A' then 0 else f' x"
  have f_in: "f  (set (cols A))  UNIV" using f'_in by auto
  have A'_in_carrier: "A'  carrier_vec n"
    by (metis (no_types) A'_in_rows A_def cols_dim carrier_matD(1) mat_of_cols_carrier(1) subset_trans)
  have lc_f: "lincomb f (set (cols A)) = 0v n"   
  proof -
    (* Split the combination into the part outside A' (l1) and the part on A' (l2). *)
    have l1: "lincomb f (set (cols A) - A') = 0v n"
      by (rule lincomb_zero, auto simp add: f_def, insert A cols_dim, blast)
    have l2: "lincomb f A' = 0v n " using lc_f' unfolding f_def using A'_in_carrier by auto
    have "lincomb f (set (cols A)) = lincomb f (set (cols A) - A') + lincomb f A'"
    proof (rule lincomb_vec_diff_add)
      show "set (cols A)  carrier_vec n"
        using A cols_dim by blast
      show "A'  set (cols A)"
        using A'_in_rows by blast
    qed auto
    also have "... =  0v n" using l1 l2 by auto
    finally show ?thesis .
  qed
  have v_in: "v  (set (cols A))" using v_A' A'_in_rows by auto 
  have fv: "f v  0" using f'_v v_A' unfolding f_def by auto
  let ?c = "(λi. f (ws ! i))"
  (* NOTE(review): the following unnamed fact is not chained into the next
     "have"; the same equation is re-derived (in symmetric form) inside the
     existence proof below. *)
  have "lincomb f (set ws) = lincomb_list ?c ws"
    by (rule lincomb_as_lincomb_list_distinct[OF _ True], auto)
  have "v.  v  carrier_vec n  v  0v n  A *v v = 0v n"
  proof (rule exI[of _ " vec (length ws) ?c"], rule conjI)
    show "vec (length ws) ?c  carrier_vec n" using A A_def by auto
    have vec_not0: "vec (length ws) ?c  0v n"
    proof -
      (* The position i of v in ws carries the coefficient f v (fact fv). *)
      obtain i where ws_i: "(ws ! i) = v" and i: "i<length ws" using v_in unfolding A_def        
        by (metis d1 cols_mat_of_cols in_set_conv_nth subset_eq)
      have "vec (length ws) ?c $ i = ?c i" by (rule index_vec[OF i])
      also have "... = f v" using ws_i by simp
      also have "...  0" using fv by simp
      finally show ?thesis
        using A A_def i by fastforce
    qed
    (* A applied to the coefficient vector is the linear combination, which is 0v n. *)
    have "A *v vec (length ws) ?c = mat_of_cols n ws *v vec (length ws) ?c" unfolding A_def ..
    also have "... = lincomb_list ?c ws" by (rule lincomb_list_as_mat_mult[symmetric, OF dimv_ws])
    also have "... = lincomb f (set ws)" 
      by (rule lincomb_as_lincomb_list_distinct[symmetric, OF _ True], auto)
    also have "... =  0v n" 
      using lc_f unfolding A_def using A by (simp add: subset_code(1))
    finally show "vec (length ws) (λi. f (ws ! i))  0v n  A *v vec (length ws) (λi. f (ws ! i)) = 0v n"
      using vec_not0 by fast
  qed 
  thus ?thesis unfolding det_0_iff_vec_prod_zero[OF A] .
qed

(* If the columns of a matrix A in "carrier_mat n n" are linearly dependent,
   then det A = 0.
   Derived from lin_dep_cols_imp_det_0' instantiated with ws = cols A; the
   rewrites col_rw and m turn "mat_of_cols n (cols A)" back into A, and the
   trailing metis call discharges the dimension premise on the columns. *)
lemma lin_dep_cols_imp_det_0:
  assumes A: "A  carrier_mat n n" and ld: "lin_dep (set (cols A))"
  shows "det A = 0" 
proof -
  have col_rw: "(cols (mat_of_cols n (cols A))) = cols A"
    using A by auto
  have m: "mat_of_cols n (cols A) = A" using A by auto
  show ?thesis
  by (rule A lin_dep_cols_imp_det_0'[of "cols A", unfolded col_rw, unfolded m, OF _ A ld])
     (metis A cols_dim carrier_matD(1) subsetCE carrier_vecD)
qed

(* Row version of lin_dep_cols_imp_det_0: linearly dependent rows of a matrix
   A in "carrier_mat n n" give det A = 0.
   det A is rewritten via det_transpose[OF A, symmetric], after which the
   column lemma is applied and its premises are closed by auto with ld and A. *)
corollary lin_dep_rows_imp_det_0:
  assumes A: "A  carrier_mat n n" and ld: "lin_dep (set (rows A))"
  shows "det A = 0" 
  by (subst det_transpose[OF A, symmetric], rule lin_dep_cols_imp_det_0, auto simp add: ld A)

(* From the assumption "det" on the determinant of a matrix A in
   "carrier_mat n n", conclude that the rows of A are linearly independent.
   Derived by auto from lin_dep_rows_imp_det_0[OF A] and the assumption "det". *)
lemma det_not_0_imp_lin_indpt_rows:
  assumes A: "A  carrier_mat n n" and det: "det A  0"  
  shows "lin_indpt (set (rows A))"
    using lin_dep_rows_imp_det_0[OF A] det by auto

(* The rows of an upper triangular matrix A in "carrier_mat n n" are linearly
   independent under the assumption "diag" about 0 and the diagonal entries
   "diag_mat A".
   Obtained by auto from det_not_0_imp_lin_indpt_rows and
   upper_triangular_imp_det_eq_0_iff. *)
lemma upper_triangular_imp_lin_indpt_rows:
  assumes A: "A  carrier_mat n n"
    and tri: "upper_triangular A"
    and diag: "0  set (diag_mat A)"
  shows "lin_indpt (set (rows A))"
  using det_not_0_imp_lin_indpt_rows upper_triangular_imp_det_eq_0_iff assms
  by auto

(* Connection from set-based to list-based *)

(* Set-based to list-based linear combinations without a distinctness
   assumption: "lincomb f (set ws)" equals a lincomb_list over ws whose
   coefficient at position i is 0 when the condition on earlier positions
   j<i (ws!i = ws!j) holds, and "f (ws ! i)" otherwise.
   Induction from the right (rev_induct) with a case split on whether the
   appended element a already occurs in ws (the cases statement below):
     - True: the new last summand is 0v n and the set is unchanged;
     - False: the new last summand is the scaled vector for a, split off
       the set-based sum with finsum_insert and appended to the list-based
       sum with M.sumlist_snoc.
   Abbreviations: ?f is the coefficient function for ws, ?g the summand
   function for "ws @ [a]", ?g2 the summand function for ws. *)
lemma lincomb_as_lincomb_list:
  fixes ws f
  assumes s: "set ws  carrier_vec n"
  shows "lincomb f (set ws) = lincomb_list (λi. if j<i. ws!i = ws!j then 0 else f (ws ! i)) ws"
  using assms
proof (induct ws rule: rev_induct)
  case (snoc a ws)
  let ?f = "λi. if j<i. ws ! i = ws ! j then 0 else f (ws ! i)"
  let ?g = "λi. (if j<i. (ws @ [a]) ! i = (ws @ [a]) ! j then 0 else f ((ws @ [a]) ! i)) v (ws @ [a]) ! i"
  let ?g2= "(λi. (if j<i. ws ! i = ws ! j then 0 else f (ws ! i)) v ws ! i)"
  have [simp]: "v. v  set ws  v  carrier_vec n" using snoc.prems(1) by auto
  then have ws: "set ws  carrier_vec n" by auto
  have hyp: "lincomb f (set ws) = lincomb_list ?f ws"
    by (intro snoc.hyps ws)  
  show ?case
  proof (cases "aset ws")
    case True    
    (* The summand at the new last position vanishes. *)
    have g_length: "?g (length ws) = 0v n" using True
      by (auto, metis in_set_conv_nth nth_append)
    have "(map ?g [0..<length (ws @ [a])]) = (map ?g [0..<length ws]) @ [?g (length ws)]"
       by auto
    also have "... = (map ?g [0..<length ws]) @ [0v n]" using g_length by simp
    finally have map_rw: "(map ?g [0..<length (ws @ [a])]) = (map ?g [0..<length ws]) @ [0v n]" .
    (* On positions below length ws, ?g2 and ?g agree; appending 0v n does not change the sum. *)
    have "M.sumlist (map ?g2 [0..<length ws]) = M.sumlist (map ?g [0..<length ws])"
      by (rule arg_cong[of _ _ "M.sumlist"], intro nth_equalityI, auto simp add: nth_append)
    also have "... =  M.sumlist (map ?g [0..<length ws]) + 0v n "
      by (metis M.r_zero calculation hyp lincomb_closed lincomb_list_def ws)
    also have "... = M.sumlist (map ?g [0..<length ws] @ [0v n])" 
      by (rule M.sumlist_snoc[symmetric], auto simp add: nth_append)
    finally have summlist_rw: "M.sumlist (map ?g2 [0..<length ws]) 
      = M.sumlist (map ?g [0..<length ws] @ [0v n])" .
    have "lincomb f (set (ws @ [a])) = lincomb f (set ws)" using True unfolding lincomb_def
      by (simp add: insert_absorb)
    thus ?thesis 
      unfolding hyp lincomb_list_def map_rw summlist_rw
      by auto
  next
    case False    
    (* The summand at the new last position is the scaled vector for a. *)
    have g_length: "?g (length ws) = f a v a" using False by (auto simp add: nth_append)
    have "(map ?g [0..<length (ws @ [a])]) = (map ?g [0..<length ws]) @ [?g (length ws)]"
       by auto
    also have "... = (map ?g [0..<length ws]) @ [(f a v a)]" using g_length by simp
    finally have map_rw: "(map ?g [0..<length (ws @ [a])]) = (map ?g [0..<length ws]) @ [(f a v a)]" .
    have summlist_rw: "M.sumlist (map ?g2 [0..<length ws]) = M.sumlist (map ?g [0..<length ws])"
      by (rule arg_cong[of _ _ "M.sumlist"], intro nth_equalityI, auto simp add: nth_append)
    have "lincomb f (set (ws @ [a])) = lincomb f (set (a # ws))" by auto
    also have "... = (Vvset (a # ws). f v v v)" unfolding lincomb_def ..
    also have "... = (Vv insert a (set ws). f v v v)" by simp    
    also have "... = (f a v a) + (Vv (set ws). f v v v)"
    proof (rule finsum_insert)
      show "finite (set ws)" by auto
      show "a  set ws" using False by auto
      show "(λv. f v v v)  set ws  carrier_vec n"
        using snoc.prems(1) by auto
      show "f a v a  carrier_vec n" using snoc.prems by auto
    qed
    also have "... = (f a v a) + lincomb f (set ws)" unfolding lincomb_def ..
    also have "... = (f a v a) + lincomb_list ?f ws" using hyp by auto
    (* Commute the two summands so that the new element sits last, as M.sumlist_snoc needs. *)
    also have "... =  lincomb_list ?f ws  + (f a v a)"
      using M.add.m_comm lincomb_list_carrier snoc.prems by auto
    also have "... = lincomb_list (λi. if j<i. (ws @ [a]) ! i 
      = (ws @ [a]) ! j then 0 else f ((ws @ [a]) ! i)) (ws @ [a])" 
    proof (unfold lincomb_list_def map_rw summlist_rw, rule M.sumlist_snoc[symmetric])
      show "set (map ?g [0..<length ws])  carrier_vec n" using snoc.prems
        by (auto simp add: nth_append)
      show "f a v a  carrier_vec n"
        using snoc.prems by auto
    qed
    finally show ?thesis .
  qed
qed auto

(* For a list vs of vectors, the list-based span "span_list vs" coincides
   with the set-based span "span (set vs)".
   After unfolding span_list_def and span_def two goals remain:
     - every lincomb_list over vs is a set-based lincomb
       (lincomb_list_as_lincomb);
     - every set-based lincomb over a finite A within "set vs" is a
       lincomb_list over vs: the coefficients are extended to "set vs" by
       ?f (value 0 on "set vs - A"), and lincomb_as_lincomb_list is used. *)
lemma span_list_as_span:
  assumes "set vs  carrier_vec n"
  shows "span_list vs = span (set vs)"
  using assms
proof (auto simp: span_list_def span_def)
  fix f show "a A. lincomb_list f vs = lincomb a A  finite A  A  set vs" 
    using assms lincomb_list_as_lincomb by auto
next
  fix f::"'a vec 'a" and A assume fA: "finite A" and A: "A  set vs" 
  have [simp]: "x  carrier_vec n" if x: "x  A" for x using A x assms by auto
  have [simp]:  "v  carrier_vec n" if v: "v  set vs" for v using assms v by auto
  have set_vs_Un: "((set vs) - A)  A = set vs" using A by auto
  let ?f = "(λx. if x(set vs) - A then 0 else f x)"
  (* The part of the sum outside A contributes 0v n. *)
  have f0: "(Vv(set vs) - A. ?f v v v) = 0v n" by (rule M.finsum_all0, auto)  
  have "lincomb f A = lincomb ?f A"
    by (auto simp add: lincomb_def intro!: finsum_cong2)
  also have "... = (Vv(set vs) - A. ?f v v v) + (VvA. ?f v v v)" 
    unfolding f0 lincomb_def by auto
  also have "... = lincomb ?f (((set vs) - A)  A)" 
    unfolding lincomb_def 
    by (rule M.finsum_Un_disjoint[symmetric], auto simp add: fA)
  also have "... = lincomb ?f (set vs)" using set_vs_Un by auto
  finally have "lincomb f A = lincomb ?f (set vs)" .    
  with lincomb_as_lincomb_list[OF assms] 
  show "c. lincomb f A = lincomb_list c vs" by auto    
qed

(* Introduction rule for span (declared [intro]): a vector given as
   "lincomb a A" for a finite A related to W as in the third assumption
   lies in "span W".  Immediate from span_def. *)
lemma in_spanI[intro]:
  assumes "v = lincomb a A" "finite A" "A  W"
  shows "v  span W"
unfolding span_def using assms by auto
(* Elimination-style counterpart of in_spanI: from membership in "span W"
   obtain coefficients a and a finite set A with "v = lincomb a A" and the
   stated relation between A and W.  Immediate from span_def. *)
lemma in_spanE:
  assumes "v  span W"
  shows " a A. v = lincomb a A  finite A  A  W"
using assms unfolding span_def by auto

declare in_own_span[intro]

(* Spans are closed under scalar multiplication: if x is in "span W" then so
   is x scaled by c.
   x is written as "lincomb a A" (in_spanE); the scaled vector is the linear
   combination over the same A with coefficients "c * a x" (finsum_smult and
   smult_smult_assoc); the [intro] rule in_spanI then applies via auto. *)
lemma smult_in_span:
  assumes "W  carrier_vec n" and insp: "x  span W"
  shows "c v x  span W"
proof -
  from in_spanE[OF insp] obtain a A where a: "x = lincomb a A" "finite A" "A  W" by blast
  have "c v x = lincomb (λ x. c * a x) A" using a(1) unfolding lincomb_def a
    apply(subst finsum_smult) using assms a by (auto simp:smult_smult_assoc)
  thus "c v x  span W" using a(2,3) by auto
qed

(* Relates "span us" to "span ws" under the two assumptions on ws and us
   (the second relates us to "span ws").
   Proved by simp from span_is_submodule and span_is_subset. *)
lemma span_subsetI: assumes ws: "ws  carrier_vec n" 
  "us  span ws" 
shows "span us  span ws" 
  by (simp add: assms(1) span_is_submodule span_is_subset subsetI ws)

end

context vec_space begin
sublocale idom_vec.

(* If every element of the list xs lies in "span W", then "sumlist xs" lies
   in "span W".
   Induction on xs; in the Cons case span_add1 gives membership of
   "x + sumlist xs", which is rewritten to "sumlist ([x] @ xs)" with
   sumlist_append. *)
lemma sumlist_in_span: assumes W: "W  carrier_vec n"  
  shows "(x. x  set xs  x  span W)  sumlist xs  span W" 
proof (induct xs)
  case Nil
  thus ?case using W by force
next
  case (Cons x xs)
  (* Carrier membership of the head and the tail, needed by sumlist_append. *)
  from span_is_subset2[OF W] Cons(2) have xs: "x  carrier_vec n" "set xs  carrier_vec n" by auto
  from span_add1[OF W Cons(2)[of x] Cons(1)[OF Cons(2)]]
  have "x + sumlist xs  span W" by auto
  also have "x + sumlist xs = sumlist ([x] @ xs)" 
    by (subst sumlist_append, insert xs, auto)
  finally show ?case by simp
qed

(* Idempotence of span (declared [simp]): "span (span W) = span W".
   Goal 1: an element x of "span (span W)" is a "lincomb a A" over a finite A
   taken from "span W"; by finite induction over A (finite_induct[OF a(2)])
   each partial sum is shown to lie in "span W", using smult_in_span for the
   new summand and span_add1 for the sum.
   Goal 2: the other direction, closed with in_own_span. *)
lemma span_span[simp]:
  assumes "W  carrier_vec n"
  shows "span (span W) = span W"
proof(standard,standard,goal_cases)
  case (1 x) with in_spanE obtain a A where a: "x = lincomb a A" "finite A" "A  span W" by blast
  from a(3) assms have AC:"A  carrier_vec n" by auto
  show ?case unfolding a(1)[unfolded lincomb_def]
  proof(insert a(3),atomize (full),rule finite_induct[OF a(2)],goal_cases)
    case 1
    then show ?case using span_zero by auto
  next
    case (2 x F)
    { assume F:"insert x F  span W"
      hence "a x v x  span W" by (intro smult_in_span[OF assms],auto)
      hence "a x v x + (VvF. a v v v)  span W"
        using span_add1 F 2 assms by auto
      hence "(Vvinsert x F. a v v v)  span W"
        apply(subst M.finsum_insert[OF 2(1,2)]) using F assms by auto
    }
    then show ?case by auto
  qed
next
  case 2
  show ?case using assms by(intro in_own_span, auto)
qed


(* Under the same three assumptions as upper_triangular_imp_lin_indpt_rows,
   the rows of A form a basis.
   Combines upper_triangular_imp_distinct and
   upper_triangular_imp_lin_indpt_rows, and applies dim_li_is_basis with
   distinct_card, dim_is_n and set_rows_carrier as simp rules. *)
lemma upper_triangular_imp_basis:
  assumes A: "A  carrier_mat n n"
    and tri: "upper_triangular A"
    and diag: "0  set (diag_mat A)"
  shows "basis (set (rows A))"
  using upper_triangular_imp_distinct[OF assms]
  using upper_triangular_imp_lin_indpt_rows[OF assms] A
  by (auto intro: dim_li_is_basis simp: distinct_card dim_is_n set_rows_carrier)

(* Specialisation of fin_dim_span to the rows of a matrix A from
   "carrier_mat nr n": the space built on the span of "set (rows A)" is
   finite dimensional.  The two premises of fin_dim_span are shown
   separately. *)
lemma fin_dim_span_rows:
assumes A: "A  carrier_mat nr n"
shows "vectorspace.fin_dim class_ring (vs (span (set (rows A))))"
proof (rule fin_dim_span) 
  show "set (rows A)  carrier V" using A rows_carrier[of A] unfolding carrier_mat_def by auto
  show "finite (set (rows A))" by auto
qed

(* Row space and column space of a matrix B: the span of the set of its rows,
   respectively of the set of its columns. *)
definition "row_space B = span (set (rows B))"
definition "col_space B = span (set (cols B))"

(* The row space of A is the column space of the matrix written AT in the
   statement.  Follows by unfolding both definitions and rewriting with
   cols_transpose. *)
lemma row_space_eq_col_space_transpose:
  shows "row_space A = col_space AT"
  unfolding col_space_def row_space_def cols_transpose ..

(* The column space of A is the row space of the matrix written AT in the
   statement.  Follows by unfolding both definitions and rewriting with
   Matrix.rows_transpose. *)
lemma col_space_eq_row_space_transpose:
  shows "col_space A = row_space AT"
  unfolding col_space_def row_space_def Matrix.rows_transpose ..


(* Characterisation of the column space of A as a set of vectors y of
   "carrier_vec (dim_row A)" described through matrix-vector products of A
   with vectors x of "carrier_vec (dim_col A)".
   Three auxiliary facts are combined after unfolding col_space_def and
   span_def:
     1. every lincomb over a finite part S of the columns lies in
        "carrier_vec (dim_row A)";
     2. every such lincomb is a product of A with some vector: coefficients
        are extended to all columns by ?g, converted to a list-based
        combination with lincomb_as_lincomb_list (coefficients ?g'), and then
        to a matrix-vector product with lincomb_list_as_mat_mult;
     3. conversely, a product of A with x is a lincomb over the columns, via
        lincomb_list_as_mat_mult and lincomb_list_as_lincomb. *)
lemma col_space_eq:
  assumes A: "A  carrier_mat n nc"
  shows "col_space A = {ycarrier_vec (dim_row A). xcarrier_vec (dim_col A). A *v x = y}"
proof -
  let ?ws = "cols A"
  have set_cols_in: "set (cols A)  carrier_vec n" using A unfolding cols_def by auto
  have "lincomb f S  carrier_vec (dim_row A)" if "finite S" and S: "S  set (cols A)" for f S 
    using lincomb_closed A
    by (metis (full_types) S carrier_matD(1) cols_dim lincomb_closed subsetCE subsetI)
  moreover have "xcarrier_vec (dim_col A). A *v x = lincomb f S" 
    if fin_S: "finite S" and S: "S  set (cols A)" for f S
  proof -    
    let ?g = "(λv. if v  S then f v else 0)"
    let ?g' = "(λi. if j<i. ?ws ! i = ?ws ! j then 0 else ?g (?ws ! i))"
    (* ?Z is the set of columns not in S. *)
    let ?Z = "set ?ws - S"
    have union: "set ?ws = S  ?Z" using S by auto
    have inter: "S  ?Z = {}" by auto    
    have "lincomb f S = lincomb ?g S" by (rule lincomb_cong, insert set_cols_in A S, auto)
    also have "... = lincomb ?g (S  ?Z)" 
      by (rule lincomb_clean[symmetric],insert set_cols_in A S fin_S, auto)
    also have "... = lincomb ?g (set ?ws)" using union by auto
    also have "... = lincomb_list ?g' ?ws" 
      by (rule lincomb_as_lincomb_list[OF set_cols_in])
    also have "... = mat_of_cols n ?ws *v vec (length ?ws) ?g'" 
      by (rule lincomb_list_as_mat_mult, insert set_cols_in A, auto)
    also have "... = A *v (vec (length ?ws) ?g')" using mat_of_cols_cols A by auto
    finally show ?thesis by auto
  qed 
  moreover have "f S. A *v x = lincomb f S  finite S  S  set (cols A)" 
    if Ax: "A *v x  carrier_vec (dim_row A)" and x: "x  carrier_vec (dim_col A)" for x 
  proof -
    (* ?c reads off the entries of x as list-position coefficients. *)
    let ?c = "λi. x $ i"
    have x_vec: "vec (length ?ws) ?c = x" using x by auto
    have "A *v x = mat_of_cols n ?ws *v vec (length ?ws) ?c" using mat_of_cols_cols A x_vec by auto
    also have "... = lincomb_list ?c ?ws" 
      by (rule lincomb_list_as_mat_mult[symmetric], insert set_cols_in A, auto)
    also have "... = lincomb (mk_coeff ?ws ?c) (set ?ws)" 
      by (rule lincomb_list_as_lincomb, insert set_cols_in A, auto)
    finally show ?thesis by auto
  qed
  ultimately show ?thesis unfolding col_space_def span_def by auto
qed

(* The structure "vs (row_space A)" built on the row space of a matrix A from
   "carrier_mat nr n" is a vector space over class_ring.
   span_vs of the row set is identified with "vs (span (set (rows A)))" and
   shown to be a vector space via span_is_subspace and subspace_is_vs;
   row_space_def is unfolded at the end. *)
lemma vector_space_row_space: 
  assumes A: "A  carrier_mat nr n"
  shows "vectorspace class_ring (vs (row_space A))"
proof -
  have fin: "finite (set (rows A))" by auto
  have s: "set (rows A)  carrier V" using A unfolding rows_def by auto
  have "span_vs (set (rows A)) = vs (span (set (rows A)))" by auto
  moreover have "vectorspace class_ring (span_vs (set (rows A)))" 
    using fin s span_is_subspace subspace_def subspace_is_vs by simp
  ultimately show ?thesis unfolding row_space_def by auto
qed

(* Row-space counterpart of col_space_eq: the row space of A is described as
   a set of vectors w of "carrier_vec (dim_col A)" through products of the
   matrix written AT with vectors y of "carrier_vec (dim_row A)".
   Obtained from col_space_eq after rewriting with
   row_space_eq_col_space_transpose. *)
lemma row_space_eq:
  assumes A: "A  carrier_mat nr n"
  shows "row_space A = {wcarrier_vec (dim_col A). ycarrier_vec (dim_row A). AT *v y = w}" 
  using A col_space_eq unfolding row_space_eq_col_space_transpose by auto

(* Left multiplication by an invertible matrix preserves the row space:
   for P in "carrier_mat m m" with invertible_mat P and A in
   "carrier_mat m n", "row_space (P*A) = row_space A".
   Both inclusions use the description of the row space from row_space_eq:
     - an element w of the row space of P*A is written with a vector y,
       and transpose_mult plus assoc_mult_mat_vec re-express it through A;
     - an element w of the row space of A is written with a vector y; an
       inverse P' of P is obtained from invertible_mat_def, P' * P is
       inserted in front of A (left_mult_one_mat), and the product is
       regrouped so that w is expressed through P*A. *)
lemma row_space_is_preserved:
  assumes inv_P: "invertible_mat P" and P: "P  carrier_mat m m" and A: "A  carrier_mat m n"
  shows "row_space (P*A) = row_space A"
proof -
  (* Carrier facts for the transposes and for the product. *)
  have At: "AT  carrier_mat n m" using A by auto
  have Pt: "PT  carrier_mat m m" using P by auto
  have PA: "P*A  carrier_mat m n" using P A by auto
  have "w  row_space A" if w: "w  row_space (P*A)" for w
  proof -
    have w_carrier: "w  carrier_vec (dim_col (P*A))"
      using w mult_carrier_mat[OF P A] row_space_eq by auto     
    from that and this obtain y where y: "y  carrier_vec (dim_row (P * A))" 
      and w_By: "w = (P*A)T *v y" unfolding row_space_eq[OF PA] by blast
    have ym: "y  carrier_vec m" using y Pt by auto
    have "w=((P*A)T) *v y" using w_By .
    also have "... = (AT * PT) *v y" using transpose_mult[OF P A] by auto
    also have "... = AT *v (PT *v y)" by (rule assoc_mult_mat_vec[OF At Pt], insert Pt y, auto)
    finally show "w  row_space A" unfolding row_space_eq[OF A] using At Pt ym by auto
  qed
    moreover have "w  row_space (P*A)" if w: "w  row_space A" for w
    proof -
      have w_carrier: "w  carrier_vec (dim_col A)" using w A unfolding row_space_eq[OF A] by auto
      (* P' is a two-sided inverse of P in the sense of inverts_mat. *)
      obtain P' where PP': "inverts_mat P P'" and P'P: "inverts_mat P' P" 
        using inv_P P unfolding invertible_mat_def by blast
      have P': "P'  carrier_mat m m" using PP' P'P P unfolding inverts_mat_def 
        by (metis carrier_matD(1) carrier_matD(2) carrier_mat_triv index_mult_mat(3) index_one_mat(3))        
      from that obtain y where y: "y  carrier_vec (dim_row A)" and 
        w_Ay: "w = AT *v y" unfolding row_space_eq[OF A] by blast
      have Py: "(P'T *v y)  carrier_vec m" using P' y A by auto
      have "w = AT *v y" using w_Ay .
      also have "... = ((P' * P)*A)T *v y" 
        using P'P left_mult_one_mat A P' unfolding inverts_mat_def by auto
      also have "... = ((P' * (P*A))T) *v y" using assoc_mult_mat_vec P' P A by auto
      also have "... = ((P*A)T * P'T) *v y" using transpose_mult P A P' mult_carrier_mat by metis        
      also have "... = (P*A)T *v (P'T *v y)" 
        using assoc_mult_mat_vec A P P' y mult_carrier_mat
        by (smt carrier_matD(1) transpose_carrier_mat)
      finally show "w  row_space (P*A)"
        unfolding row_space_eq[OF PA] 
        using Py w_carrier A P by fastforce
    qed
  ultimately show ?thesis by auto
qed

end

context vec_module begin

(* R.sumlist coincides with the list function sum_list (declared [simp]).
   Shown by function extensionality and induction on the argument list. *)
lemma R_sumlist[simp]: "R.sumlist = sum_list" 
proof (intro ext) 
  fix xs
  show "R.sumlist xs = sum_list xs" by (induct xs, auto)
qed

(* "sumlist xs" has dimension n when the elements of xs are in
   "carrier_vec n".  Follows by fastforce from sumlist_carrier. *)
lemma sumlist_dim: assumes " x. x  set xs  x  carrier_vec n"
  shows "dim_vec (sumlist xs) = n"
  using sumlist_carrier assms
  by fastforce
    
(* Entrywise description of sumlist in terms of sum_list: for i < n, the
   i-th entry of "sumlist xs" is the sum_list of the i-th entries of the
   elements of xs.
   After unfolding M.sumlist_def (which exposes a foldr over the list, as
   seen in IH), the proof is an induction on xs; the Cons case uses
   index_add_vec together with sumlist_dim for the index bound. *)
lemma sumlist_vec_index: assumes " x. x  set xs  x  carrier_vec n"
  and "i < n" 
shows "sumlist xs $ i = sum_list (map (λ x. x $ i) xs)" 
  unfolding M.sumlist_def using assms(1) proof(induct xs)
  case (Cons a xs)
  hence cond:" x. x  set xs  x  carrier_vec n" by auto
  from Cons(1)[OF cond] have IH:"foldr (+) xs (0v n) $ i = (xxs. x $ i)" by auto
  have "(a + foldr (+) xs (0v n)) $ i = a $ i + (xxs. x $ i)" 
    apply(subst index_add_vec) unfolding IH
    using sumlist_dim[OF cond,unfolded M.sumlist_def] assms by auto
  then show ?case by auto next
  case Nil thus ?case using assms by auto
qed
 
(* Distributivity of the scalar product over sumlist in the left argument:
   the product of "sumlist vvs" with w is the sum_list of the products of
   the individual elements of vvs with w.
   Induction on vvs; the Cons case splits off the head with sumlist_append
   and applies add_scalar_prod_distrib.  The base case is closed by
   "qed (insert w, auto)". *)
lemma scalar_prod_left_sum_distrib: 
  assumes vs: " v. v  set vvs  v  carrier_vec n" and w: "w  carrier_vec n" 
  shows "sumlist vvs  w = sum_list (map (λ v. v  w) vvs)"
  using vs
proof (induct vvs)
  case (Cons v vs)
  from Cons have v: "v  carrier_vec n" and vs: "sumlist vs  carrier_vec n" 
    by (auto intro!: sumlist_carrier)
  have "sumlist (v # vs)  w = sumlist ([v] @ vs)  w " by auto
  also have " = (v + sumlist vs)  w" 
    by (subst sumlist_append, insert Cons v vs, auto)
  also have " = v  w + (sumlist vs  w)" 
    by (rule add_scalar_prod_distrib[OF v vs w])
  finally show ?case using Cons by auto
qed (insert w, auto)   

(* Right-argument counterpart of scalar_prod_left_sum_distrib.
   The product is commuted with comm_scalar_prod, the left-argument lemma is
   applied, and the two mapped lists are shown equal elementwise
   (nth_equalityI), again using comm_scalar_prod. *)
lemma scalar_prod_right_sum_distrib: 
  assumes vs: " v. v  set vvs  v  carrier_vec n" and w: "w  carrier_vec n" 
  shows "w  sumlist vvs = sum_list (map (λ v. w  v) vvs)"
  by (subst comm_scalar_prod[OF w sumlist_carrier], insert vs w, force,
  subst scalar_prod_left_sum_distrib[OF vs w], force,
  rule arg_cong[of _ _ sum_list], rule nth_equalityI, 
  auto simp: set_conv_nth intro!: comm_scalar_prod)

(* Effect of an elementary "add a multiple of entry j to entry i" update of
   the vector list on list-based linear combinations:
   if x is the lincomb_list with coefficients lc over the list us in which
   position i has been replaced by "us ! i" plus c times "us ! j", then x is
   also the lincomb_list over the unmodified us with the coefficient at
   position j changed to "lc j + lc i * c" (abbreviated ?xx; the result is
   abbreviated ?x).
   Proof outline: both sums are rewritten as multiset sums
   (sumlist_as_summset), the index multiset is split into i, j and the rest
   (fact mset), the contributions of the rest are M1 and M2 and are shown
   equal, and the two explicit forms x1 and x2 are compared entrywise. *)
lemma lincomb_list_add_vec_2: assumes us: "set us  carrier_vec n" 
  and x: "x = lincomb_list lc (us [i := us ! i + c v us ! j])"
  and i: "j < length us" "i < length us" "i  j" 
shows "x = lincomb_list (lc (j := lc j + lc i * c)) us" (is "_ = ?x")
proof -
  let ?xx = "lc j + lc i * c" 
  let ?i = "us ! i" 
  let ?j = "us ! j" 
  let ?v = "?i + c v ?j" 
  let ?ws = "us [i := us ! i + c v us ! j]"
  (* Carrier facts for the entries of us, the new entry ?v and the updated list ?ws. *)
  from us have usk: "k < length us  us ! k  carrier_vec n" for k by auto
  from usk i have ij: "?i  carrier_vec n" "?j  carrier_vec n" by auto
  hence v: "c v ?j  carrier_vec n" "?v  carrier_vec n" by auto
  with us have ws: "set ?ws  carrier_vec n" unfolding set_conv_nth using i 
    by (auto, rename_tac k, case_tac "k = i", auto)
  from us have us': "wset us. dim_vec w = n" by auto 
  from ws have ws': "wset ?ws. dim_vec w = n" by auto 
  (* Split the multiset of positions into i, j and all remaining positions. *)
  have mset: "mset_set {0..<length us} = {#i#} + {#j#} + (mset_set ({0..<length us} - {i,j}))" 
    by (rule multiset_eqI, insert i, auto, rename_tac x, case_tac "x  {0 ..< length us}", auto)
  define M2 where "M2 = M.summset
      {#lc ia v ?ws ! ia. ia ∈# mset_set ({0..<length us} - {i, j})#}" 
  define M1 where "M1 = M.summset {#(if i = j then ?xx else lc i) v us ! i. i ∈# mset_set ({0..<length us} - {i, j})#}" 
  have M1: "M1  carrier_vec n" unfolding M1_def using usk by fastforce
  have M2: "M1 = M2" unfolding M2_def M1_def
    by (rule arg_cong[of _ _ M.summset], rule multiset.map_cong0, insert i usk, auto) 
  (* Explicit form of the given x (combination over the updated list). *)
  have x1: "x = lc j v ?j + (lc i v ?i + lc i v (c v ?j) + M1)" 
    unfolding x lincomb_list_def M2 M2_def
    apply (subst sumlist_as_summset, (insert us ws i v ij, auto simp: set_conv_nth)[1], insert i ij v us ws usk, 
      simp add: mset smult_add_distrib_vec[OF ij(1) v(1)])
    by (subst M.summset_add_mset, auto)+
  (* Explicit form of the target ?x (combination over us with updated coefficient). *)
  have x2: "?x = ?xx v ?j + (lc i v ?i + M1)"
    unfolding x lincomb_list_def M1_def
    apply (subst sumlist_as_summset, (insert us ws i v ij, auto simp: set_conv_nth)[1], insert i ij v us ws usk, 
      simp add: mset smult_add_distrib_vec[OF ij(1) v(1)])
    by (subst M.summset_add_mset, auto)+
  show ?thesis unfolding x1 x2 using M1 ij
    by (intro eq_vecI, auto simp: field_simps)
qed

(* Counterpart of lincomb_list_add_vec_2 in the other direction: if x is the
   list-based linear combination of us with coefficients lc, then x is also a
   linear combination of the list in which entry i has been replaced by
   us ! i + c (scalar-times) us ! j; the coefficient at index j is changed to
   lc j - lc i * c.
   The proof has the same structure as lincomb_list_add_vec_2: split the index
   multiset (mset), identify the sums over the remaining indices (M1 = M2),
   expand both sides (x1, x2) and compare componentwise. *)
lemma lincomb_list_add_vec_1: assumes us: "set us  carrier_vec n" 
  and x: "x = lincomb_list lc us"
  and i: "j < length us" "i < length us" "i  j" 
shows "x = lincomb_list (lc (j := lc j - lc i * c)) (us [i := us ! i + c v us ! j])" (is "_ = ?x")
proof -
  let ?i = "us ! i" 
  let ?j = "us ! j" 
  let ?v = "?i + c v ?j" 
  let ?ws = "us [i := us ! i + c v us ! j]"
  (* carrier facts for the entries of us and of the updated list ?ws *)
  from us have usk: "k < length us  us ! k  carrier_vec n" for k by auto
  from usk i have ij: "?i  carrier_vec n" "?j  carrier_vec n" by auto
  hence v: "c v ?j  carrier_vec n" "?v  carrier_vec n" by auto
  with us have ws: "set ?ws  carrier_vec n" unfolding set_conv_nth using i 
    by (auto, rename_tac k, case_tac "k = i", auto)
  from us have us': "wset us. dim_vec w = n" by auto 
  from ws have ws': "wset ?ws. dim_vec w = n" by auto 
  (* decomposition of the index multiset into i, j and the rest *)
  have mset: "mset_set {0..<length us} = {#i#} + {#j#} + (mset_set ({0..<length us} - {i,j}))" 
    by (rule multiset_eqI, insert i, auto, rename_tac x, case_tac "x  {0 ..< length us}", auto)
  (* M2: remaining summands w.r.t. the updated coefficients and ?ws;
     M1: remaining summands w.r.t. lc and us *)
  define M2 where "M2 = M.summset
      {#(if ia = j then lc j - lc i * c else lc ia) v ?ws ! ia
      . ia ∈# mset_set ({0..<length us} - {i, j})#}" 
  define M1 where "M1 = M.summset {#lc i v us ! i. i ∈# mset_set ({0..<length us} - {i, j})#}" 
  have M1: "M1  carrier_vec n" unfolding M1_def using usk by fastforce
  have M2: "M1 = M2" unfolding M2_def M1_def
    by (rule arg_cong[of _ _ M.summset], rule multiset.map_cong0, insert i usk, auto) 
  have x1: "x = lc j v ?j + (lc i v ?i + M1)" 
    unfolding x lincomb_list_def M1_def
    apply (subst sumlist_as_summset, (insert us ws i v ij, auto simp: set_conv_nth)[1], insert i ij v us ws usk, 
      simp add: mset smult_add_distrib_vec[OF ij(1) v(1)])
    by (subst M.summset_add_mset, auto)+
  have x2: "?x = (lc j - lc i * c) v ?j + (lc i v ?i + lc i v (c v ?j) + M1)"
    unfolding x lincomb_list_def M2 M2_def
    apply (subst sumlist_as_summset, (insert us ws i v ij, auto simp: set_conv_nth)[1], insert i ij v us ws usk, 
      simp add: mset smult_add_distrib_vec[OF ij(1) v(1)])
    by (subst M.summset_add_mset, auto)+
  show ?thesis unfolding x1 x2 using M1 ij
    by (intro eq_vecI, auto simp: field_simps)
qed

end

context vec_space
begin
(* Replacing entry i of a list us of carrier vectors by
   us ! i + c (scalar-times) us ! j, for valid indices i and j, does not change
   the span of the list's elements.
   The proof passes through span_list and shows both inclusions using
   lincomb_list_add_vec_1 and lincomb_list_add_vec_2. *)
lemma add_vec_span: assumes us: "set us  carrier_vec n" 
  and i: "j < length us" "i < length us" "i  j" 
shows "span (set us) = span (set (us [i := us ! i + c v us ! j]))" (is "_ = span (set ?ws)")
proof -
  let ?i = "us ! i" 
  let ?j = "us ! j" 
  let ?v = "?i + c v ?j" 
  from us i have ij: "?i  carrier_vec n" "?j  carrier_vec n" by auto
  hence v: "?v  carrier_vec n" by auto
  with us have ws: "set ?ws  carrier_vec n" unfolding set_conv_nth using i 
    by (auto, rename_tac k, case_tac "k = i", auto)
  have "span (set us) = span_list us" unfolding span_list_as_span[OF us] ..
  also have " = span_list ?ws"
  proof -
    {
      (* every element of span_list us is in span_list ?ws *)
      fix x
      assume "x  span_list us" 
      then obtain lc where "x = lincomb_list lc us" by (metis in_span_listE)
      from lincomb_list_add_vec_1[OF us this i, of c]
      have "x  span_list ?ws" unfolding span_list_def by auto
    }
    moreover
    {
      (* every element of span_list ?ws is in span_list us *)
      fix x
      assume "x  span_list ?ws" 
      then obtain lc where "x = lincomb_list lc ?ws" by (metis in_span_listE)
      from lincomb_list_add_vec_2[OF us this i]
      have "x  span_list us" unfolding span_list_def by auto
    }
    ultimately show ?thesis by blast
  qed
  also have " = span (set ?ws)" unfolding span_list_as_span[OF ws] ..
  finally show ?thesis .
qed

(* A scalar multiple of b by a lies in span S, for a carrier vector b and a set
   S of carrier vectors, as soon as a = 0 or b itself lies in span S (the proof
   distinguishes exactly these two cases).
   Declared as an intro! rule.
   In the non-zero case b is written as lincomb aa A for a finite A inside S,
   and the product is the linear combination with coefficients a * aa v. *)
lemma prod_in_span[intro!]:
  assumes "b  carrier_vec n" "S  carrier_vec n" "a = 0  b  span S"
  shows "a v b  span S"
proof(cases "a = 0")
  case True
  then show ?thesis by (auto simp:lmult_0[OF assms(1)] span_zero)
next
  case False with assms have "b  span S" by auto
  from this[THEN in_spanE]
  obtain aa A where a[intro!]: "b = lincomb aa A" "finite A" "A  S" by auto
  hence [intro!]:"(λv. aa v v v)  A  carrier_vec n" using assms by auto 
  show ?thesis proof
    show "a v b = lincomb (λ v. a * aa v) A" using a(1) unfolding lincomb_def smult_smult_assoc[symmetric]
      by(subst finsum_smult[symmetric]) force+
  qed auto
qed

(* Right cancellation for n-by-n matrices: from A * M = B * M and the
   determinant assumption det on M, conclude A = B.
   Proof outline: running gauss_jordan on M together with the identity matrix
   yields the identity in the first component, hence the second component Mi
   satisfies M * Mi = identity; multiplying the equation eq by Mi on the right
   and reassociating gives the claim. *)
lemma det_nonzero_congruence:
  assumes eq:"A * M = B * M" and det:"det (M::'a mat)  0"
  and M: "M  carrier_mat n n" and carr:"A  carrier_mat n n" "B  carrier_mat n n"
  shows "A = B"
proof -
  have "1m n  carrier_mat n n" by auto
  from det_non_zero_imp_unit[OF M det] gauss_jordan_check_invertable[OF M this]
  have gj_fst:"(fst (gauss_jordan M (1m n)) = 1m n)" by metis
  (* Mi: second component of the Gauss-Jordan result, used as right inverse of M *)
  define Mi where "Mi = snd (gauss_jordan M (1m n))"
  with gj_fst have gj:"gauss_jordan M (1m n) = (1m n, Mi)"
    unfolding fst_def snd_def by (auto split:prod.split)
  from gauss_jordan_compute_inverse(1,3)[OF M gj]
  have Mi: "Mi  carrier_mat n n" and is1:"M * Mi = 1m n" by metis+
  from arg_cong[OF eq, of "λ M. M * Mi"]
  show "A = B" unfolding carr[THEN assoc_mult_mat[OF _ M Mi]] is1 carr[THEN right_mult_one_mat].
qed

(* Matrix-vector product as a linear combination of the columns: for a list lst
   of vectors in carrier_vec n and a vector v of dimension length lst,
   mat_of_cols n lst times v equals lincomb f (set lst), where f l adds up the
   entries v $ i over all positions i with lst ! i = l (so repeated list
   elements are accounted for).
   NOTE: the lemma is named mat_of_rows_mult_as_finsum, but the statement is
   about mat_of_cols; the conclusion is additionally labelled
   mat_of_cols_mult_as_finsum in the shows clause.
   The proof compares both sides entry by entry. *)
lemma mat_of_rows_mult_as_finsum:
  assumes "v  carrier_vec (length lst)" " i. i < length lst  lst ! i  carrier_vec n"
  defines "f l  sum (λ i. if l = lst ! i then v $ i else 0) {0..<length lst}"
  shows mat_of_cols_mult_as_finsum:"mat_of_cols n lst *v v = lincomb f (set lst)"
proof -
  from assms have " i < length lst. lst ! i  carrier_vec n" by blast
  note an = all_nth_imp_all_set[OF this] hence slc:"set lst  carrier_vec n" by auto
  hence dn [simp]:" x. x  set lst  dim_vec x = n" by auto
  have dl [simp]:"dim_vec (lincomb f (set lst)) = n" using an by (intro lincomb_dim,auto)
  show ?thesis proof
    show "dim_vec (mat_of_cols n lst *v v) = dim_vec (lincomb f (set lst))" using assms(1,2) by auto
    fix i assume i:"i < dim_vec (lincomb f (set lst))" hence i':"i < n" by auto
    with an have fcarr:"(λv. f v v v)  set lst  carrier_vec n" by auto
    (* entry i of the product is the scalar product of row i with v *)
    from i' have "(mat_of_cols n lst *v v) $ i = row (mat_of_cols n lst) i  v" by auto
    also have " = (ia = 0..<dim_vec v. lst ! ia $ i * v $ ia)"
      unfolding mat_of_cols_def row_def scalar_prod_def
      apply(rule sum.cong[OF refl]) using i an assms(1) by auto
    also have " = (ia = 0..<length lst. lst ! ia $ i * v $ ia)" using assms(1) by auto
    (* regroup the sum over positions into a sum over the distinct list elements *)
    also have " = (xset lst. f x * x $ i)"
      unfolding f_def sum_distrib_right apply (subst sum.swap)
      apply(rule sum.cong[OF refl])
      unfolding if_distrib if_distribR mult_zero_left sum.delta[OF finite_set] by auto
    also have " = (xset lst. (f x v x) $ i)"
      apply(rule sum.cong[OF refl],subst index_smult_vec) using i slc by auto
    also have " = (Vvset lst. f v v v) $ i"
      unfolding finsum_index[OF i' fcarr slc] by auto
    finally show "(mat_of_cols n lst *v v) $ i = lincomb f (set lst) $ i"
      by (auto simp:lincomb_def)
  qed
qed

end

end

Theory Gram_Schmidt

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Gram-Schmidt Orthogonalization›

text ‹
  This theory provides the Gram-Schmidt orthogonalization algorithm,
  that takes the conjugate operation into account. It works over fields
  like the rational, real, or complex numbers. 
›

theory Gram_Schmidt
imports 
  VS_Connect 
  Missing_VectorSpace
  Conjugate
begin

subsection ‹Orthogonality with Conjugates›

(* A list vs is corthogonal if, for all valid indices i and j, the conjugate
   scalar product of vs ! i and vs ! j is zero exactly when i and j differ.
   In particular the product of an entry with itself is required to be
   non-zero. *)
definition "corthogonal vs 
    i < length vs. j < length vs. vs ! i ∙c vs ! j = 0  i  j"

(* Destruction rule for corthogonal: unfolds the definition for two given
   valid indices i and j. Declared as an elim rule. *)
lemma corthogonalD[elim]:
  "corthogonal vs  i < length vs  j < length vs 
   vs ! i ∙c vs ! j = 0  i  j"
  unfolding corthogonal_def by auto

(* Introduction rule for corthogonal: it suffices to establish the defining
   property for arbitrary valid indices i and j. Declared as an intro rule. *)
lemma corthogonalI[intro]:
  "(i j. i < length vs  j < length vs  vs ! i ∙c vs ! j = 0  i  j) 
   corthogonal vs"
  unfolding corthogonal_def by auto

(* A corthogonal list has no repeated elements.
   Induction on the list. In the Cons case: if the head u occurred in the tail
   at position j, then the product of u with us ! j would be zero (indices 0
   and j+1 differ), hence the product of u with itself would be zero,
   contradicting corthogonality at indices 0 and 0. The tail is corthogonal
   because its indices are those of the full list shifted by one. *)
lemma corthogonal_distinct: "corthogonal us  distinct us"
proof (induct us)
  case (Cons u us)
    have "u  set us"
    proof
      assume "u : set us"
      then obtain j where uj: "u = us!j" and j: "j < length us"
        using in_set_conv_nth by metis
      hence j': "j+1 < length (u#us)" by auto
      have "u ∙c us!j = 0"
        using corthogonalD[OF Cons(2) _ j',of 0] by auto
      hence "u ∙c u = 0" using uj by simp
      thus False using corthogonalD[OF Cons(2),of 0 0] by auto
    qed
    moreover have "distinct us"
    proof (rule Cons(1),intro corthogonalI)
      fix i j assume "i < length (us)" "j < length (us)"
      hence len: "i+1 < length (u#us)" "j+1 < length (u#us)" by auto
      show "(us!i ∙c us!j = 0) = (ij)"
        using corthogonalD[OF Cons(2) len] by simp
    qed
    ultimately show ?case by simp
qed simp

(* Corthogonality transfers from a list us to any distinct list us' with the
   same set of elements (for instance a reordering of us).
   For indices i', j' of us' the proof picks indices i, j of us pointing to the
   same elements and translates (in)equality of indices into (in)equality of
   elements via distinctness of both lists. *)
lemma corthogonal_sort:
  assumes dist': "distinct us'"
      and mem: "set us = set us'"
  shows "corthogonal us  corthogonal us'"
proof
  assume orth: "corthogonal us"
  hence dist: "distinct us" using corthogonal_distinct by auto
  fix i' j' assume i': "i' < length us'" and j': "j' < length us'"
  obtain i where ii': "us!i = us'!i'" and i: "i < length us"
    using mem i' in_set_conv_nth by metis
  obtain j where jj': "us!j = us'!j'" and j: "j < length us"
    using mem j' in_set_conv_nth by metis
  from corthogonalD[OF orth i j]
  have "(us!i ∙c us!j = 0) = (i  j)".
  hence "(us'!i' ∙c us'!j' = 0) = (i  j)" using ii' jj' by auto
  also have "... = (us!i  us!j)" using nth_eq_iff_index_eq dist i j by auto
  also have "... = (us'!i'  us'!j')" using ii' jj' by auto
  also have "... = (i'  j')" using nth_eq_iff_index_eq dist' i' j' by auto
  finally show "(us'!i' ∙c us'!j' = 0) = (i'  j')".
qed

subsection‹The Algorithm›

(* adjuster n w us is the correction vector that gram_schmidt_sub adds to w:
   the sum, over the elements u of the list us, of u scaled by
   -(w ∙c u)/(u ∙c u). For the empty list it is the zero vector of dimension n. *)
fun adjuster :: "nat  'a :: conjugatable_field vec  'a vec list  'a vec"
  where "adjuster n w [] = 0v n"
    |  "adjuster n w (u#us) = -(w ∙c u)/(u ∙c u) v u + adjuster n w us"

text ‹
  The following formulation is easier to analyze,
  but outputs of the subroutine should be properly reversed.
›

(* Accumulator-style worker: us holds the vectors produced so far, most recent
   first. Each input vector w is adjusted against us and the result is put in
   front of us; when the input is exhausted the accumulator is returned, so the
   output is in reverse order of the input. *)
fun gram_schmidt_sub
  where "gram_schmidt_sub n us [] = us"
  | "gram_schmidt_sub n us (w # ws) =
     gram_schmidt_sub n ((adjuster n w us + w) # us) ws"

(* Entry point: runs gram_schmidt_sub with an empty accumulator and reverses
   its result, so that the output order corresponds to the input order. *)
definition gram_schmidt :: "nat  'a :: conjugatable_field vec list  'a vec list"
  where "gram_schmidt n ws = rev (gram_schmidt_sub n [] ws)"

text ‹
  The following formulation requires no reversal.
›

(* Variant of gram_schmidt_sub that emits each adjusted vector u directly in
   input order. The argument us still collects the previously produced vectors
   (most recent first) but only serves as input to adjuster; it is not part of
   the result. *)
fun gram_schmidt_sub2
  where "gram_schmidt_sub2 n us [] = []"
  | "gram_schmidt_sub2 n us (w # ws) =
     (let u = adjuster n w us + w in
      u # gram_schmidt_sub2 n (u # us) ws)"

(* Relates the two workers: reversing the result of gram_schmidt_sub gives the
   reversed accumulator followed by the result of gram_schmidt_sub2.
   Proved by induction on ws for arbitrary us. *)
lemma gram_schmidt_sub_eq:
  "rev (gram_schmidt_sub n us ws) = rev us @ gram_schmidt_sub2 n us ws"
  by (induct ws arbitrary:us, auto simp:Let_def)

(* Code equation for gram_schmidt (attribute [code]): it is expressed via
   gram_schmidt_sub2 with an empty accumulator, which needs no final reversal.
   Follows from gram_schmidt_sub_eq. *)
lemma gram_schmidt_code[code]:
  "gram_schmidt n ws = gram_schmidt_sub2 n [] ws"
  unfolding gram_schmidt_def
  apply(subst gram_schmidt_sub_eq) by simp

subsection ‹Properties of the Algorithms›

(* Locale extending vec_space with a type parameter f_ty whose element type 'a
   belongs to the class conjugatable_ordered_field. The lemmas about adjuster
   and gram_schmidt below are proved inside this locale. *)
locale cof_vec_space = vec_space f_ty for
  f_ty :: "'a :: conjugatable_ordered_field itself"
begin

(* For a distinct list us of carrier vectors, the recursively defined adjuster
   equals the finite sum (finsum V) over set us of the summands
   -(w ∙c u)/(u ∙c u) scalar-times u.
   Induction on us; distinctness is what allows finsum_insert to be applied in
   the Cons case. *)
lemma adjuster_finsum:
  assumes U: "set us  carrier_vec n"
    and dist: "distinct (us :: 'a vec list)"
  shows "adjuster n w us = finsum V (λu. -(w ∙c u)/(u ∙c u) v u) (set us)"
  using assms
proof (induct us)
  case Cons show ?case unfolding set_simps
  by (subst finsum_insert[OF finite_set], insert Cons, auto)
qed simp

(* For a carrier vector w and a distinct list us of carrier vectors, adjuster
   is the linear combination (lincomb) of set us with coefficient function
   u |-> -(w ∙c u)/(u ∙c u), abbreviated ?a.
   Proved by induction on us after unfolding lincomb_def. *)
lemma adjuster_lincomb:
  assumes w: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
  shows "adjuster n w us = lincomb (λu. -(w ∙c u)/(u ∙c u)) (set us)"
    (is "_ = lincomb ?a _")
  using us dist unfolding lincomb_def
proof (induct us)
  case (Cons u us)
    let ?f = "λu. ?a u v u"
    have "?f : (set us)  carrier_vec n" and "?f u : carrier_vec n" using w Cons by auto
    moreover have "u  set us" using Cons by auto
    ultimately show ?case
      unfolding adjuster.simps
      unfolding set_simps
      using finsum_insert[OF finite_set] Cons by auto
qed simp

(* Under the assumptions of adjuster_lincomb, adjuster n w us lies in the span
   of set us. Direct consequence of adjuster_lincomb and finite_span. *)
lemma adjuster_in_span:
  assumes w: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
  shows "adjuster n w us : span (set us)"
  using adjuster_lincomb[OF assms]
  unfolding finite_span[OF finite_set us] by auto

(* Under the same assumptions, adjuster n w us is a vector in carrier_vec n.
   Obtained from adjuster_in_span and span_closed. Declared as a simp rule. *)
lemma adjuster_carrier[simp]:
  assumes w: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
  shows "adjuster n w us : carrier_vec n"
  using adjuster_in_span span_closed assms by auto

(* If w is outside the span of set us (assumption ind), then the adjusted
   vector adjuster n w us + w is outside that span as well.
   Uses span_add together with the fact that the adjuster itself lies in the
   span. *)
lemma adjust_not_in_span:
  assumes w[simp]: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
    and ind: "w  span (set us)"
  shows "adjuster n w us + w  span (set us)"
  using span_add[OF us adjuster_in_span[OF w us dist] w]
  using comm_add_vec ind by auto

(* If w is outside the span of set us, the adjusted vector adjuster n w us + w
   is not an element of set us. Follows from adjust_not_in_span and span_mem. *)
lemma adjust_not_mem:
  assumes w[simp]: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
    and ind: "w  span (set us)"
  shows "adjuster n w us + w  set us"
  using adjust_not_in_span[OF assms] span_mem[OF us] by auto

(* The adjusted vector adjuster n w us + w lies in the span of
   insert w (set us), abbreviated ?U.
   Both summands are in span ?U: the adjuster (?v) because it is in
   span (set us) and span is monotone, and w because it is a member of ?U. *)
lemma adjust_in_span:
  assumes w[simp]: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
  shows "adjuster n w us + w : span (insert w (set us))" (is "?v + _ : span ?U")
proof -
  let ?a = "λu. -(w ∙c u)/(u ∙c u)"
  have "?v = lincomb ?a (set us)" using adjuster_lincomb[OF assms].
  hence vU: "?v : span (set us)" unfolding finite_span[OF finite_set us] by auto
  hence v[simp]: "?v : carrier_vec n" using span_closed[OF us] by auto
  have vU': "?v : span ?U" using vU span_is_monotone[OF subset_insertI] by auto

  have "{w}  ?U" by simp
  from span_is_monotone[OF this]
  have wU': "w : span ?U" using span_self[OF w] by auto

  have "?U  carrier_vec n" using us w by simp
  from span_add[OF this wU' v] vU' comm_add_vec[OF w]
  show ?thesis by simp
qed

(* Linear independence is preserved: if set us is not linearly dependent and w
   is outside its span (assumption wus), then inserting the adjusted vector
   adjuster n w us + w (abbreviated ?v) again gives a set that is not linearly
   dependent.
   Combines adjust_not_in_span, adjust_not_mem and lin_dep_iff_in_span. *)
lemma adjust_not_lindep:
  assumes w[simp]: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
    and wus: "w  span (set us)"
    and ind: "~ lin_dep (set us)"
  shows "~ lin_dep (insert (adjuster n w us + w) (set us))"
    (is "~ _ (insert ?v _)")
proof -
  have v: "?v : carrier_vec n" using assms by auto
  have "?v  span (set us)"
    using adjust_not_in_span[OF w us dist wus]
    using comm_add_vec[OF adjuster_carrier[OF w us dist] w] by auto
  thus ?thesis
    using lin_dep_iff_in_span[OF us ind v] adjust_not_mem[OF w us dist wus] by auto
qed

(* Membership in span (set us) is the same for w and for the adjusted vector
   adjuster n w us + w. (gram_schmidt_sub_span uses this fact both for the case
   that w is in the span and for the case that it is not.)
   Note: in this lemma ?v abbreviates the adjuster alone, not the sum.
   The proof relies on span_add and the adjuster being in the span. *)
lemma adjust_preserves_span:
  assumes w[simp]: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
  shows "w : span (set us)  adjuster n w us + w : span (set us)"
    (is "_  ?v + _ : _")
proof -
  have "?v : span (set us)"
    using adjuster_lincomb[OF assms]
    unfolding finite_span[OF finite_set us] by auto
  hence [simp]: "?v : carrier_vec n" using span_closed[OF us] by auto
  show ?thesis
    using span_add[OF us adjuster_in_span[OF w us] w] comm_add_vec[OF w] dist
    by auto
qed

(* The original vector w can be recovered from the adjusted one: w lies in the
   span of insert ?v (set us), where ?v abbreviates adjuster n w us + w.
   Idea: w = - adjuster n w us + ?v; the negated adjuster is in the span since
   the adjuster is in span (set us), and ?v is in the span as a member of the
   generating set. *)
lemma in_span_adjust:
  assumes w[simp]: "(w :: 'a vec) : carrier_vec n"
    and us: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
  shows "w : span (insert (adjuster n w us + w) (set us))"
    (is "_ : span (insert ?v _)")
proof -
  have v: "?v : carrier_vec n" using assms by auto
  have a[simp]: "adjuster n w us : carrier_vec n"
   and neg: "- adjuster n w us : carrier_vec n" using assms by auto
  hence vU: "insert ?v (set us)  carrier_vec n" using us by auto
  have aS: "adjuster n w us : span (insert ?v (set us))"
    using adjuster_in_span[OF w us] span_is_monotone[OF subset_insertI] dist
    by auto
  have negS: "- adjuster n w us : span (insert ?v (set us))"
    using span_neg[OF vU aS] us by simp
  have [simp]:"- adjuster n w us + (adjuster n w us + w) = w"
    unfolding a_assoc[OF neg a w,symmetric] by simp
  have "{?v}  insert ?v (set us)" by simp
  from span_is_monotone[OF this]
  have vS: "?v : span (insert ?v (set us))" using span_self[OF v] by auto
  thus ?thesis using span_add[OF vU negS v] by auto
qed

(* Key orthogonality property: for a corthogonal list us of carrier vectors and
   a carrier vector w, the conjugate scalar product of the adjusted vector
   adjuster n w us + w with any entry us ! i is zero.
   Proof outline: write the adjuster as a finite sum (adjuster_finsum) and
   distribute the product over it. With u = us ! i, every summand belonging to
   an element other than u contributes zero by corthogonality, and the summand
   for u contributes -(w ∙c u) (fact main), which cancels against the
   contribution of w. *)
lemma adjust_zero:
  assumes U: "set (us :: 'a vec list)  carrier_vec n"
    and orth: "corthogonal us"
    and w[simp]: "w : carrier_vec n"
    and i: "i < length us"
  shows "(adjuster n w us + w) ∙c us!i = 0"
proof -
  define u where "u = us!i"
  have u[simp]: "u : carrier_vec n" using i U u_def by auto
  hence cu[simp]: "conjugate u : carrier_vec n" by auto
  have uU: "u : set us" using i u_def by auto
  (* ?g: summand of the adjuster; ?f: its conjugate scalar product with u *)
  let ?g = "λu'::'a vec. (-(w ∙c u')/(u' ∙c u') v u')"
  have g: "?g : set us  carrier_vec n" using w U by auto
  hence carrier: "finsum V ?g (set us) : carrier_vec n" by simp
  let ?f = "λu'. ?g u' ∙c u"
  let ?U = "set us - {u}"
  { fix u' assume u': "(u'::'a vec) : carrier_vec n"
    have [simp]: "dim_vec u = n" by auto
    have "?f u' = (- (w ∙c u') / (u' ∙c u')) * (u' ∙c u)"
      using scalar_prod_smult_left[of "u'" "conjugate u"]
      unfolding carrier_vecD[OF u] carrier_vecD[OF u'] by auto
  } note conv = this
  (* all summands for elements other than u vanish *)
  have "?f : ?U  {0}"
  proof (intro Pi_I)
    fix u' assume u'Uu: "u' : set us - {u}"
    hence u'U: "u' : set us" by auto
    hence u'[simp]: "u' : carrier_vec n" using U by auto
    obtain j where j: "j < length us" and u'j: "u' = us ! j"
      using u'U in_set_conv_nth by metis
    have "i  j" using u'Uu u'j u_def by auto
    hence "u' ∙c u = 0"
      unfolding u'j using corthogonalD[OF orth j i] u_def by auto
    hence "?f u' = 0" using mult_zero_right conv[OF u'] by auto
    thus "?f u' : {0}" by auto
  qed
  hence "restrict ?f ?U = restrict (λu. 0) ?U" by force
  hence "sum ?f ?U = sum (λu. 0) ?U"
    by (intro R.finsum_restrict, auto)
  hence fU'0: "sum ?f ?U = 0" by auto
  (* split off the summand for u itself *)
  have uU': "u  ?U" by auto
  have "set us = insert u ?U"
    using insert_Diff_single uU by auto
  hence "sum ?f (set us) = ?f u + sum ?f ?U"
    using R.finsum_insert[OF _ uU'] by auto
  also have "... = ?f u" using fU'0 by auto
  also have "... = - (w ∙c u) / (u ∙c u) * (u ∙c u)"
    using conv[OF u] by auto
  finally have main: "sum ?f (set us) = - (w ∙c u)"
    unfolding u_def
    by (simp add: i orth corthogonalD)
  show ?thesis
    unfolding u_def[symmetric]
    unfolding adjuster_finsum[OF U corthogonal_distinct[OF orth]]
    unfolding add_scalar_prod_distrib[OF carrier w cu]
    unfolding finsum_scalar_prod_sum[OF g cu]
    unfolding main
    unfolding comm_scalar_prod[OF cu w]
    using left_minus by auto
qed

(* If w is outside the span of set us (assumption wsU), the adjusted vector
   adjuster n w us + w (with ?a the adjuster) is not the zero vector.
   Proof by contradiction: if the sum were zero, then w = - ?a, which lies in
   span (set us) because the adjuster does. *)
lemma adjust_nonzero:
  assumes U: "set (us :: 'a vec list)  carrier_vec n"
    and dist: "distinct us"
    and w[simp]: "w : carrier_vec n"
    and wsU: "w  span (set us)"
  shows "adjuster n w us + w  0v n" (is "?a + _  _")
proof
  have [simp]: "?a : carrier_vec n" using U dist by auto
  have [simp]: "- ?a : carrier_vec n" by auto
  have [simp]: "?a + w : carrier_vec n" by auto
  assume "?a + w = 0v n"
  hence "- ?a = - ?a + (?a + w)" by auto
  also have "... = (- ?a + ?a) + w" apply(subst a_assoc) by auto
  also have "- ?a + ?a = 0v n" using r_neg[OF w] unfolding vec_neg[OF w] by auto
  finally have "- ?a = w" by auto
  moreover have "- ?a : span (set us)"
    using span_neg[OF U adjuster_in_span[OF w U dist]] by auto
  ultimately show "False" using wsU by auto
qed

(* Putting the adjusted vector ?aw = adjuster n w us + w in front of a
   corthogonal list us of carrier vectors yields a corthogonal list again,
   provided w is outside span (set us).
   The proof distinguishes whether each of the indices i, j is 0 (pointing to
   ?aw) or a successor (pointing into us):
   - both 0: the product of ?aw with itself is non-zero since ?aw is non-zero
     (adjust_nonzero);
   - exactly one is 0: the product is zero by adjust_zero (using conjugation
     and commutativity when ?aw is the second argument);
   - none is 0: corthogonality of us at the shifted indices. *)
lemma adjust_orthogonal:
  assumes U: "set (us :: 'a vec list)  carrier_vec n"
    and orth: "corthogonal us"
    and w[simp]: "w : carrier_vec n"
    and wsU: "w  span (set us)"
  shows "corthogonal ((adjuster n w us + w) # us)"
    (is "corthogonal (?aw # _)")
proof
  have dist: "distinct us" using corthogonal_distinct orth by auto
  have aw[simp]: "?aw : carrier_vec n" using U dist by auto
  note adjust_nonzero[OF U dist w] wsU
  hence aw0: "?aw ∙c ?aw  0" using conjugate_square_eq_0_vec[OF aw] by auto
  fix i j assume i: "i < length (?aw # us)" and j: "j < length (?aw # us)"
  show "((?aw # us) ! i ∙c (?aw # us) ! j = 0) = (i  j)"
  proof (cases "i = 0")
    case True note i0 = this
      show ?thesis
      proof (cases "j = 0")
        case True show ?thesis unfolding True i0 using aw0 by auto
        next case False
          define j' where "j' = j-1"
          hence jfold: "j = j'+1" using False by auto
          hence j': "j' < length us" using j by auto
          show ?thesis unfolding i0 jfold
            using adjust_zero[OF U orth w j'] by auto
      qed
    next case False
      define i' where "i' = i-1"
      hence ifold: "i = i'+1" using False by auto
      hence i': "i' < length us" using i by auto
      have [simp]: "us ! i' : carrier_vec n" using U i' by auto
      hence cu': "conjugate (us ! i') : carrier_vec n" by auto
      show ?thesis
      proof (cases "j = 0")
        case True
          { assume "?aw ∙c us ! i' = 0"
            hence "conjugate (?aw ∙c us ! i') = 0" using conjugate_zero by auto
            hence "conjugate ?aw  us ! i' = 0"
              using conjugate_sprod_vec[OF aw cu'] by auto
          }
          thus ?thesis unfolding True ifold
          using adjust_zero[OF U orth w i']
          by (subst comm_scalar_prod[of _ n], auto)
        next case False
          define j' where "j' = j-1"
          hence jfold: "j = j'+1" using False by auto
          hence j': "j' < length us" using j by auto
          show ?thesis
            unfolding ifold jfold
            using orth i' j' by (auto simp: corthogonalD)
     qed
  qed
qed

(* One step of the algorithm preserves the span: replacing w by the adjusted
   vector ?v = adjuster n w us + w in front of us gives the same span as w # us.
   Case w in span (set us): then ?v is in that span too, and both sides reduce
   to span (set us) via already_in_span.
   Case w not in span (set us): both inclusions are obtained from span_swap,
   using adjust_in_span and in_span_adjust. *)
lemma gram_schmidt_sub_span:
  assumes w[simp]: "w : carrier_vec n"
    and us: "set us  carrier_vec n"
    and dist: "distinct us"
  shows "span (set ((adjuster n w us + w) # us)) = span (set (w # us))"
  (is "span (set (?v # _)) = span ?wU")
proof (cases "w : span (set us)")
  case True
    hence "?v : span (set us)"
      using adjust_preserves_span[OF assms] by auto
    thus ?thesis using already_in_span[OF us] True by auto next
  case False show ?thesis
    proof
      have wU: "?wU  carrier_vec n" using us by simp 
      have vswU: "?v : span ?wU" using adjust_in_span[OF assms] by auto
      hence v: "?v : carrier_vec n" using span_closed[OF wU] by auto
      have wsvU: "w : span (insert ?v (set us))" using in_span_adjust[OF assms].
      show "span ?wU  span (set (?v # us))"
        using span_swap[OF finite_set us w False v wsvU] by auto
      have "?v  span (set us)"
        using False adjust_preserves_span[OF assms] by auto
      thus "span (set (?v # us))  span ?wU"
        using span_swap[OF finite_set us v _ w] vswU by auto
    qed
qed

(* Invariant of the accumulator-style worker. Assume gram_schmidt_sub n us ws
   returns us', all vectors of us and ws are carrier vectors, us @ ws is
   distinct and not linearly dependent, and us is corthogonal. Then us' consists
   of carrier vectors, is distinct and corthogonal, spans the same space as
   us @ ws, and has length length us + length ws.
   Proved by induction on ws for arbitrary us and us'. In the Cons case the
   hypotheses are re-established for the extended accumulator
   (adjuster n w us + w) # us (facts vwU, dist2, ind2, orth2) and the span
   equality follows from gram_schmidt_sub_span and span_Un.
   Note: the fact name vwU is bound three times in the Cons case; the last
   binding (the carrier fact) is the one referred to in the final step. *)
lemma gram_schmidt_sub_result:
  assumes "gram_schmidt_sub n us ws = us'"
    and "set ws  carrier_vec n"
    and "set us  carrier_vec n"
    and "distinct (us @ ws)"
    and "~ lin_dep (set (us @ ws))"
    and "corthogonal us"
  shows "set us'  carrier_vec n 
         distinct us' 
         corthogonal us' 
         span (set (us @ ws)) = span (set us')  length us' = length us + length ws"  
  using assms
proof (induct ws arbitrary: us us')
case (Cons w ws)
  (* ?v is the adjuster; the new accumulator element is ?v + w *)
  let ?v = "adjuster n w us"
  have wW[simp]: "set (w#ws)  carrier_vec n" using Cons by simp
  hence W[simp]: "set ws  carrier_vec n"
   and w[simp]: "w : carrier_vec n" by auto
  have U[simp]: "set us  carrier_vec n" using Cons by simp
  have UW: "set (us@ws)  carrier_vec n" by simp
  have wU: "set (w#us)  carrier_vec n" by simp
  have dist: "distinct (us @ w # ws)" using Cons by simp
  hence dist_U: "distinct us"
    and dist_W: "distinct ws"
    and dist_UW: "distinct (us @ ws)"
    and w_U: "w  set us"
    and w_W: "w  set ws"
    and w_UW: "w  set (us @ ws)" by auto
  have ind: "~ lin_dep (set (us @ w # ws))" using Cons by simp
  have ind_U: "~ lin_dep (set us)"
    and ind_W: "~ lin_dep (set ws)"
    and ind_wU: "~ lin_dep (insert w (set us))"
    and ind_UW: "~ lin_dep (set (us @ ws))"
    by (subst subset_li_is_li[OF ind];auto)+
  have corth: "corthogonal us" using Cons by simp
  have U'def: "gram_schmidt_sub n ((?v + w)#us) ws = us'" using Cons by simp

  have v: "?v : carrier_vec n" using dist_U by auto
  hence vw: "?v + w : carrier_vec n" by auto
  hence vwU: "set ((?v + w) # us)  carrier_vec n" by auto
  have vsU: "?v : span (set us)" using adjuster_in_span[OF w] dist by auto
  hence vsUW: "?v : span (set (us @ ws))"
    using span_is_monotone[of "set us" "set (us@ws)"] by auto
  have wsU: "w  span (set us)"
    using lin_dep_iff_in_span[OF U ind_U w w_U] ind_wU by auto
  hence vwU: "?v + w  span (set us)" using adjust_not_in_span[OF w U dist_U] by auto

  have "w  span (set (us@ws))" using lin_dep_iff_in_span[OF _ ind_UW] dist ind by auto
  hence span: "?v + w  span (set (us@ws))" using span_add[OF UW vsUW w] by auto
  hence vwUS: "?v + w  set (us @ ws)" using span_mem by auto
  hence ind2: "~ lin_dep (set (((?v + w) # us) @ ws))"
    using lin_dep_iff_in_span[OF UW ind_UW vw] span by auto

  have vwU: "set ((?v + w) # us)  carrier_vec n" using U w dist by auto
  have dist2: "distinct (((?v + w) # us) @ ws)" using dist vwUS by simp

  have orth2: "corthogonal ((adjuster n w us + w) # us)"
    using adjust_orthogonal[OF U corth w wsU].

  show ?case
    using Cons(1)[OF U'def W vwU dist2 ind2] orth2
    using span_Un[OF vwU wU gram_schmidt_sub_span[OF w U dist_U] W W] by auto
    
qed simp

(* The first output vector of gram_schmidt is the first input vector, for a
   carrier vector w (the adjuster for an empty accumulator is the zero vector).
   Declared as a simp rule. *)
lemma gram_schmidt_hd [simp]:
  assumes [simp]: "w : carrier_vec n" shows "hd (gram_schmidt n (w#ws)) = w"
  unfolding gram_schmidt_code by simp

(* Main result: for a distinct list ws of carrier vectors that is not linearly
   dependent, the output us = gram_schmidt n ws spans the same space as ws, is
   corthogonal, consists of carrier vectors, has the same length as ws and is
   distinct.
   The properties are first obtained for rev us from gram_schmidt_sub_result
   (started with the empty accumulator) and then transferred to us;
   corthogonality of us uses corthogonal_sort, since us and rev us have the
   same elements. *)
theorem gram_schmidt_result:
  assumes ws: "set ws  carrier_vec n"
    and dist: "distinct ws"
    and ind: "~ lin_dep (set ws)"
    and us: "us = gram_schmidt n ws"
  shows "span (set ws) = span (set us)"
    and "corthogonal us"
    and "set us  carrier_vec n"
    and "length us = length ws"
    and "distinct us"
proof -
  have main: "gram_schmidt_sub n [] ws = rev us"
    using us unfolding gram_schmidt_def
    using gram_schmidt_sub_eq by auto
  have orth: "corthogonal []" by auto
  have "span (set ws) = span (set (rev us))"
   and orth2: "corthogonal (rev us)"
   and "set us  carrier_vec n"
   and "length us = length ws"
   and dist: "distinct us" 
    using gram_schmidt_sub_result[OF main ws]
    by (auto simp: assms orth)
  thus "span (set ws) = span (set us)" by simp
  show "set us  carrier_vec n" by fact
  show "length us = length ws" by fact
  show "distinct us" by fact
  show "corthogonal us"
    using corthogonal_distinct[OF orth2] unfolding distinct_rev
    using corthogonal_sort[OF _ set_rev orth2] by auto
qed
end

end

Theory Schur_Decomposition

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Schur Decomposition›

text ‹We implement Schur decomposition as an algorithm which, given a square matrix $A$
  and a list of eigenvalues, computes $B$, $P$, and $Q$ such that 
  $A = PBQ$, $B$ is upper-triangular and $PQ = 1$. The algorithm is generic in
  the kind of field and can be applied to the rationals, the reals, and the complex numbers.
  The algorithm relies on the method of Gram-Schmidt to create an orthogonal basis,
  and on the Gauss-Jordan algorithm to find eigenvectors to a given eigenvalue.
  
 The algorithm is a key ingredient to show that every matrix with a linear factorizable 
 characteristic polynomial has a Jordan normal form. 

  A further consequence of the algorithm is that the characteristic polynomial of 
  a block diagonal matrix is the product of the characteristic polynomials of the blocks.›

theory Schur_Decomposition
imports 
  Polynomial_Interpolation.Missing_Polynomial
  Gram_Schmidt 
  Char_Poly
begin

(* vec_inv v is the conjugate of v scaled by 1 / (v ∙c v). Lemma vec_inv shows
   that its scalar product with v is 1 for a non-zero v over a conjugatable
   ordered field. *)
definition vec_inv :: "'a::conjugatable_field vec  'a vec"
  where "vec_inv v = 1 / (v ∙c v) v conjugate v"

(* vec_inv maps carrier_vec n into carrier_vec n. Declared as a simp rule. *)
lemma vec_inv_closed[simp]: "v  carrier_vec n  vec_inv v  carrier_vec n"
  unfolding vec_inv_def by auto

(* vec_inv preserves the dimension of its argument. Declared as a simp rule. *)
lemma vec_inv_dim[simp]: "dim_vec (vec_inv v) = dim_vec v"
  unfolding vec_inv_def by auto

(* For a non-zero carrier vector v over a conjugatable ordered field, the
   scalar product of vec_inv v with v is 1.
   The proof first rules out v ∙c v = 0 (that would force v to be the zero
   vector, by conjugate_square_eq_0_vec), then rewrites the product of
   conjugate v with v into v ∙c v and pulls the scalar factor out.
   Declared as a simp rule. *)
lemma vec_inv[simp]:
  assumes v: "v : carrier_vec n"
      and v0: "(v::'a::conjugatable_ordered_field vec)  0v n"
  shows "vec_inv v  v = 1"
proof -
  { assume "v ∙c v = 0"
    hence "v = 0v n" using conjugate_square_eq_0_vec[OF v] by auto
    hence False using v0 by auto
  }
  moreover have "conjugate v  v = v ∙c v"
    apply (rule comm_scalar_prod) using v by auto
  ultimately show ?thesis
    unfolding vec_inv_def
    apply (subst smult_scalar_prod_distrib)
    using assms by auto
qed

(* For a corthogonal list vs of carrier vectors, the matrix ?W whose rows are
   the vectors vec_inv (vs ! i) inverts (in the sense of inverts_mat) the
   matrix ?V whose columns are the vectors vs ! i.
   After unfolding inverts_mat_def the proof shows that the product ?W * ?V has
   l rows and l columns (l = length vs) and that its entry (i,j) is 1 if i = j
   and 0 otherwise; the entry is the scalar product of row i of ?W with column
   j of ?V, which is evaluated using corthogonality. *)
lemma corthogonal_inv:
  assumes orth: "corthogonal (vs ::'a::conjugatable_field vec list)"
      and V: "set vs  carrier_vec n"
  shows "inverts_mat (mat_of_rows n (map vec_inv vs)) (mat_of_cols n vs)"
    (is "inverts_mat ?W ?V")
proof -
  define l where "l = length vs"
  have rW[simp]: "dim_row ?W = l" using l_def by auto
  have cV[simp]:"dim_col ?V = l" using l_def by auto
  have dim: "i. i < length vs  vs!i  carrier_vec n" using V by auto
  show ?thesis
    unfolding inverts_mat_def
    apply rule
    unfolding mat_of_rows_carrier length_map l_def[symmetric]
    unfolding index_one_mat
  proof -
    show "dim_row (?W * ?V) = l" "dim_col (?W * ?V) = l"
      unfolding times_mat_def rW cV by auto
    fix i j assume i:"i<l" and j: "j<l"
    hence i2: "i<length vs"
      and i3: "i<length (map vec_inv vs)"
      and j2: "j<length vs" using l_def by auto
    hence id2: "vs ! i  carrier_vec n"
      and id3: "map vec_inv vs ! i  carrier_vec n"
      and id4: "conjugate (vs ! i)  carrier_vec n"
      and jd2: "vs ! j  carrier_vec n" using dim by auto
    (* entry (i,j) of the product, step by step reduced to a conjugate scalar
       product of list entries *)
    show "(?W * ?V) $$ (i,j) = (if i = j then 1 else 0)"
      unfolding times_mat_def rW cV
      unfolding index_mat[OF i j] split
      unfolding mat_of_rows_row[OF i3 id3]
      unfolding col_mat_of_cols[OF j2 jd2]
      unfolding nth_map[OF i2]
      unfolding vec_inv_def
      unfolding smult_scalar_prod_distrib[OF id4 jd2]
      unfolding comm_scalar_prod[OF id4 jd2]
      using corthogonalD[OF orth j2 i2] by auto
  qed
qed

(* Matrix built from A by applying vec_inv to each column of A and using the
   results as rows. Lemma corthogonal_inv_result shows that it inverts A
   whenever A satisfies corthogonal_mat.
   Note: this constant shares its name with the lemma corthogonal_inv. *)
definition corthogonal_inv :: "'a::conjugatable_field mat  'a mat"
  where "corthogonal_inv A = mat_of_rows (dim_row A) (map vec_inv (cols A))"

(* Matrix whose rows are the conjugated columns of A (each row has length
   dim_row A). *)
definition mat_adjoint :: "'a :: conjugatable_field mat  'a mat"
  where "mat_adjoint A  mat_of_rows (dim_row A) (map conjugate (cols A))"

(* Matrix-level counterpart of corthogonal: with B = mat_adjoint A * A, the
   matrix B has to be diagonal and its diagonal entries B $$ (i,i), for
   i < dim_col A, have to be non-zero (this is how corthogonal_matD unpacks the
   definition). *)
definition corthogonal_mat :: "'a::conjugatable_field mat  bool"
  where "corthogonal_mat A 
    let B = mat_adjoint A * A in
    diagonal_mat B  (i<dim_col A. B $$ (i,i)  0)"

(* Destruction rule for corthogonal_mat: for column indices i and j of A, the
   conjugate scalar product of columns i and j is zero exactly when i and j
   differ. Declared as an elim rule.
   Both directions are read off the entries of ?B = mat_adjoint A * A: the
   diagonal entries are non-zero and the off-diagonal entries are zero;
   conjugation is applied to pass from the product with the conjugated first
   column to the conjugate scalar product. *)
lemma corthogonal_matD[elim]:
  assumes orth: "corthogonal_mat A"
      and i: "i < dim_col A"
      and j: "j < dim_col A"
  shows "(col A i ∙c col A j = 0) = (i  j)"
proof
  have ci: "col A i : carrier_vec (dim_row A)"
   and cj: "col A j : carrier_vec (dim_row A)" by auto
  note [simp] = conjugate_conjugate_sprod[OF ci cj]

  let ?B = "mat_adjoint A * A"
  have diag: "diagonal_mat ?B" and zero: "i. i<dim_col A  ?B $$ (i,i)  0"
    using orth unfolding corthogonal_mat_def Let_def by auto
  { assume "i = j"
    hence "conjugate (col A i)  col A j  0"
      using zero[OF i] unfolding mat_adjoint_def using i by simp
    hence "conjugate (conjugate (col A i)  col A j)  0"
      unfolding conjugate_zero_iff.
    hence "col A i ∙c col A j  0" by simp
  }
  thus "col A i ∙c col A j = 0  i  j" by auto
  { assume "i  j"
    hence "conjugate (col A i)  col A j = 0"
      using diag
      unfolding diagonal_mat_def
      unfolding mat_adjoint_def using i j by simp
    hence "conjugate (conjugate (col A i)  col A j) = 0" by simp
    thus "col A i ∙c col A j = 0" by simp
  }
qed

(* Introduction rule for corthogonal_mat: it suffices that for all column
   indices i and j the conjugate scalar product of columns i and j is zero
   exactly when i and j differ. Declared as an intro rule.
   The two blocks establish the off-diagonal and the diagonal condition on
   mat_adjoint A * A, respectively. *)
lemma corthogonal_matI[intro]:
  assumes "(i j. i < dim_col A  j < dim_col A  (col A i ∙c col A j = 0) = (i  j))"
  shows "corthogonal_mat A"
proof -
  { fix i j assume i: "i < dim_col A" and j: "j < dim_col A" and ij: "i  j"
    have "conjugate (col A i)  col A j = 0"
      by (metis assms col_dim i j ij conjugate_vec_sprod_comm)
  }
  moreover
  { fix i assume "i < dim_col A"
    hence "conjugate (col A i)  col A i  0"
      by (metis assms comm_scalar_prod carrier_vec_conjugate carrier_vecI)
  }
  ultimately show ?thesis
  unfolding corthogonal_mat_def Let_def
  unfolding diagonal_mat_def
  unfolding mat_adjoint_def by auto
qed

(* For a matrix A satisfying corthogonal_mat, corthogonal_inv A inverts A in the
   sense of inverts_mat (corthogonal_inv A) A.
   Proof: the list cols A is shown to be corthogonal (via corthogonal_matD), and
   then the list-level fact corthogonal_inv is applied together with cols_dim. *)
lemma corthogonal_inv_result:
  assumes o: "corthogonal_mat (A::'a::conjugatable_field mat)"
  shows "inverts_mat (corthogonal_inv A) A"
proof -
  have oc: "corthogonal (cols A)"
    apply (intro corthogonalI) using corthogonal_matD[OF o] by auto
  show ?thesis unfolding corthogonal_inv_def
    using corthogonal_inv[OF oc cols_dim] by auto
qed

text "extends a vector to a basis"

(* basis_completion v returns the list v # vs, where
     n          = dim_vec v,
     drop_index = head of the list of indices i in [0..<n] selected by a test on
                  v $ i against 0 (presumably the first index with a non-zero entry),
     vs         = unit vectors unit_vec n i for i in [0..<n], filtered by a test
                  against drop_index (the lemma basis_completion below shows that
                  exactly the index drop_index is left out).
   NOTE(review): hd is applied to a possibly empty list; the lemma below only
   treats vectors v that differ from the zero vector. *)
definition basis_completion :: "'a::ring_1 vec  'a vec list" where
  "basis_completion v  let 
     n = dim_vec v;
     drop_index = hd ([ i . i <- [0..<n], v $ i  0]);
     vs = [unit_vec n i. i <- [0..<n], i  drop_index] 
   in v # vs"

(* Correctness of basis_completion inside locale vec_space.
   Assumptions: v is in carrier_vec n and v is compared with the zero vector 0v n
   (presumably "v is not the zero vector").
   Conclusions (seven facts, in this order): the set of the resulting list is a
   basis, lies in carrier_vec n, has span carrier_vec n, the list is distinct,
   its set is not linearly dependent, its length is n, and its head is v.

   Proof outline:
     1. pick the first selected index k (the drop_index of the definition);
     2. rewrite the list as v followed by unit vectors for [0..<k] and [Suc k..<n];
     3. describe the i-th list element via the index map I and show distinctness;
     4. show every w in carrier_vec n is a linear combination (coefficients a);
     5. conclude basis / independence from span, distinctness and length. *)
lemma (in vec_space) basis_completion: fixes v :: "'a :: field vec"
  assumes v: "v  carrier_vec n"
      and v0: "v  0v n"
  shows 
    "basis (set (basis_completion v))"
    "set (basis_completion v)  carrier_vec n"
    "span (set (basis_completion v)) = carrier_vec n" 
    "distinct (basis_completion v)"
    "¬ lin_dep (set (basis_completion v))"
    "length (basis_completion v) = n"
    "hd (basis_completion v) = v"
proof -
  let ?b = "basis_completion v"
  note d = basis_completion_def Let_def
  from v have dim: "dim_vec v = n" by auto
  (* ?is is the index list whose head is drop_index in the definition *)
  let ?is = "[ i . i <- [0..<n], v $ i  0]"
  {
    assume empty: "set ?is = {}"
    have "v = 0v n"
      by (rule eq_vecI, insert empty v, auto)
  }
  (* step 1: ?is is non-empty, so it has a head k *)
  with v0 obtain k ids where id: "?is = k # ids" and mem: "k  set ?is" by (cases ?is, auto)
  from mem have vk: "v $ k  0" and k: "k < n" by auto
  {
    fix i 
    assume i: "¬ i < k"
    have id: "k # [Suc k..<n] = [k ..< n]" using k by (simp add: upt_conv_Cons)
    from i have "i < n  (k # [Suc k..<n]) ! (i - k) = i" 
      unfolding id
      by (subst nth_upt, auto)
  }
  (* step 2: split the index range around k *)
  hence split: "[0 ..< n] = [0 ..< k] @ k # [Suc k ..< n]"
    by (intro nth_equalityI, insert k, auto simp: nth_append) 
  {
    fix as
    assume "k  set as"
    hence "[unit_vec n i. i <- as, i  k] = [unit_vec n i. i <- as]"
      by (induct as, auto)
  } note conv = this
  have b_all: "?b = v # [unit_vec n i. i <- [0..<n], i  k]"
    unfolding d dim id by simp 
  also have "[unit_vec n i. i <- [0..<n], i  k] = [unit_vec n i. i <- [0..<k]] @ [unit_vec n i. i <- [Suc k..<n]]"
    unfolding split by (auto simp: conv)
  finally have b: "?b = v # [unit_vec n i. i <- [0..<k]] @ [unit_vec n i. i <- [Suc k..<n]]" by simp
  show carr: "set ?b  carrier_vec n" (is "?S  _")
    unfolding b using assms by auto
  show "hd ?b = v" unfolding b by auto
  show len: "length (basis_completion v) = n" unfolding b using k
    by auto
  (* step 3: I maps a position in the unit-vector part to the unit-vector index,
     skipping k *)
  define I where "I = (λ i. if i < k then i else Suc i)"
  have I: " i. I i  k" " i. Suc i < n  I i < n" unfolding I_def by auto
  {
    fix i
    assume i: "i < n"
    have "?b ! i = (if i = 0 then v else unit_vec n (I (i - 1)))"
      unfolding b I_def using i
      by (auto split: if_splits simp: nth_append)
  } note bi = this
  show dist: "distinct ?b" unfolding distinct_conv_nth len
  proof (intro allI impI)
    fix i j
    assume i: "i < n" and j: "j < n" and ij: "i  j"
    show "?b ! i  ?b ! j"
    proof 
      assume id1: "?b ! i = ?b ! j" 
      hence id2: " l. ?b ! i $ l = ?b ! j $ l" by auto
      have "i = j" 
      proof (cases "i = 0")
        case True
        hence biv: "?b ! i = v" unfolding b by simp
        from True ij have bj: "?b ! j = unit_vec n (I (j - 1))" "Suc (j - 1) = j" unfolding bi[OF j] by auto
        with id2[of k, unfolded biv bj] vk I[of "j - 1"] k j
        have False by simp
        thus ?thesis ..
      next
        case False note i0 = this
        hence bi': "?b ! i = unit_vec n (I (i - 1))" "Suc (i - 1) = i" unfolding bi[OF i] by auto
        show ?thesis
        proof (cases "j = 0")
          case True
          hence bj: "?b ! j = v" unfolding b by simp
          from id2[of k, unfolded bi' bj] vk I[of "i - 1"] k i bi'
          have False by simp
          thus ?thesis by simp
        next
          case False note j0 = this
          hence bj: "?b ! j = unit_vec n (I (j - 1))" "Suc (j - 1) = j" unfolding bi[OF j] by auto
          have "1 = ?b ! i $ I (i - 1)" unfolding bi' using I[of "i - 1"] i i0 by auto
          also have " = unit_vec n (I (j - 1)) $ I (i - 1)" unfolding id1 bj by simp
          also have " = (if I (i - 1) = I (j - 1) then 1 else 0)"
            using I[of "i - 1"] I[of "j - 1"] i0 j0 i j by auto
          finally have "I (i - 1) = I (j - 1)" by (auto split: if_splits)
          with i0 j0 show "i = j" unfolding I_def by (auto split: if_splits)
        qed
      qed   
      thus False using ij by simp
    qed
  qed
  (* step 4: span equals the carrier; one inclusion from carr, the other by an
     explicit linear combination with coefficient function a *)
  have "span (set ?b)  carrier_vec n" using carr by auto
  moreover
  {
    fix w :: "'a vec"
    assume w: "w  carrier_vec n"
    (* association list from each basis vector to its index (v is paired with k) *)
    define lookup where "lookup = (v,k) # [(unit_vec n i, i). i <- [0..<n], i  k]"
    define a where "a = (λ vi. case map_of lookup vi of Some i  if i = k then w $ k / v $ k else
       w $ i - w $ k / v $ k * v $ i)" 
    have "map fst lookup = ?b" unfolding b_all lookup_def 
      by (auto simp: map_concat o_def if_distrib, unfold list.simps fst_def prod.simps, simp)
    with dist have dist: "distinct (map fst lookup)" by simp
    let ?w = "lincomb a (set ?b)"
    have "?w  carrier_vec n" using carr by auto
    with w have dim: "dim_vec w = n" "dim_vec ?w = n" by auto
    have "w = ?w" 
    proof (rule eq_vecI; unfold dim)
      fix i
      assume i: "i < n"
      show "w $ i = ?w $ i" unfolding lincomb_def 
      proof (subst finsum_index[OF i _ carr]) 
        show "(λv. a v v v)  set ?b  carrier_vec n" using carr by auto
        {
          fix x :: "'a vec" and j
          assume "x = unit_vec n j" "j  k" "j < n"
          hence "(x,j)  set lookup" unfolding lookup_def by auto
          from map_of_is_SomeI[OF dist this]
          have "a x = w $ j - w $ k / v $ k * v $ j" unfolding a_def using j  k by auto
        } note a = this          
        have "(xset ?b. (a x v x) $ i) = (a v v v) $ i + (x(set ?b) - {v}. (a x v x) $ i)"
          by (rule sum.remove[OF finite_set], auto simp: b)
        also have "a v = w $ k / v $ k" unfolding a_def lookup_def by auto
        also have "( v v) $ i = w $ k / v $ k * v $ i" using i v by auto
        finally have "(xset ?b. (a x v x) $ i) = w $ k / v $ k * v $ i + (x(set ?b) - {v}. (a x v x) $ i)" .
        also have " = w $ i"
        proof (cases "i = k")
          case True
          hence "w $ k / v $ k * v $ i = w $ k" using vk by auto
          moreover have "(x(set ?b) - {v}. (a x v x) $ i) = 0" unfolding True
          proof (rule sum.neutral, intro ballI)
            fix x
            assume "x  set ?b - {v}"
            then obtain j where x: "x = unit_vec n j" "j  k" "j < n" using k unfolding b by auto
            show "(a x v x) $ k = 0" unfolding a[OF x] unfolding x using x k by auto
          qed
          ultimately show ?thesis unfolding True by simp
        next
          case False
          let ?ui = "unit_vec n i :: 'a vec"
          {
            assume "?ui = v"
            from arg_cong[OF this, of "λ v. v $ k"] vk i k False have False by auto
          }
          hence diff: "?ui  v" by auto
          from a[OF refl False] have ai: "(a ?ui v ?ui) $ i = w $ i - w $ k / v $ k * v $ i" 
            using i by auto          
          have "?ui  set ?b" unfolding b_all using False k i by auto
          with diff have mem: "unit_vec n i  set ?b - {v}" by auto
          have "w $ k / v $ k * v $ i + (x(set ?b) - {v}. (a x v x) $ i)
            = w $ i + (x(set ?b) - {v,?ui}. (a x v x) $ i)"
            by (subst sum.remove[OF _ mem], auto simp: ai intro!: sum.cong)
          also have "(x(set ?b) - {v,?ui}. (a x v x) $ i) = 0"
            by (rule sum.neutral, unfold b_all, insert i k, auto)
          finally show ?thesis by simp
        qed
        finally show "w $ i = (xset ?b. (a x v x) $ i)" by simp
      qed
    qed auto
    hence "w  span (set ?b)" unfolding span_def by auto
  }
  ultimately show span: "span (set ?b) = carrier_vec n" by blast
  (* step 5: a spanning set whose cardinality is related to dim as required by
     dim_gen_is_basis is a basis *)
  show "basis (set ?b)"
  proof (rule dim_gen_is_basis[OF finite_set carr span])
    have "card (set ?b) = dim" using dist len distinct_card unfolding dim_is_n by blast
    thus "card (set ?b)  dim" by simp
  qed
  thus "¬ lin_dep (set ?b)" unfolding basis_def by auto
qed

(* If ws is a corthogonal list of length n whose elements lie in carrier_vec n,
   then the matrix mat_of_cols n ws satisfies corthogonal_mat.
   The initial proof step presumably applies corthogonal_matI (declared [intro]),
   leaving the column-wise characterisation to be shown for fixed i and j. *)
lemma orthogonal_mat_of_cols:
  assumes W: "set ws  carrier_vec n"
    and orth: "corthogonal ws"
    and len: "length ws = n"
  shows "corthogonal_mat (mat_of_cols n ws)" (is "corthogonal_mat ?W")
proof
    fix i j assume i: "i < dim_col ?W" and j: "j < dim_col ?W"
    hence [simp]: "ws ! i : carrier_vec n" "ws ! j : carrier_vec n"
      using W len by auto
    have "i < length ws" and "j < length ws" using i j using len W by auto
    thus "col ?W i ∙c col ?W j = 0  i  j"
      using orth
      unfolding corthogonal_def
      by simp
qed

(* First column after a change of basis whose first basis vector is an eigenvector.
   Setting: A is n x n, v is in carrier_vec n, A *v v equals e times v (assumption
   eigen), ws is a corthogonal list of n vectors with hd ws = v, and
     W = mat_of_cols n ws,  W' = corthogonal_inv W,  A' = W' * A * W.
   Claim: column 0 of A' is the vector of dimension n with e at index 0 and 0 at
   every other index.
   Proof: entry i of that column is computed as (row i of W') times (A *v v);
   row i of W' is vec_inv (ws ! i), and the product with v is evaluated by a case
   split on i = 0. *)
lemma corthogonal_col_ev_0: fixes A :: "'a :: conjugatable_ordered_field mat"
  assumes A: "A  carrier_mat n n"
  and v: "v  carrier_vec n"
  and v0: "v  0v n"
  and eigen[simp]: "A *v v = e v v"
  and n: "n  0"
  and hdws: "hd ws = v"
  and ws: "set ws  carrier_vec n" "corthogonal ws" "length ws = n"
  defines "W == mat_of_cols n ws"
  defines "W' == corthogonal_inv W"
  defines "A' == W' * A * W"
  shows "col A' 0 = vec n (λ i. if i = 0 then e else 0)"
proof -
  let ?f = "(λ i. if i = 0 then e else 0)"
  (* carrier facts for W, W' and A' *)
  from ws have W: "W  carrier_mat n n" unfolding W_def by auto
  from W have W': "W'  carrier_mat n n" unfolding W'_def 
    corthogonal_inv_def mat_of_rows_def by auto
  from A W W' have A': "A'  carrier_mat n n" unfolding A'_def by auto
  show "col A' 0 = vec n ?f"
  proof (rule,unfold dim_vec)
    show dim: "dim_vec (col A' 0) = n" using A' by simp
    have row0: "vec_inv v  (A *v v) = e"
      using scalar_prod_smult_distrib[OF vec_inv_closed[OF v] v]
      using vec_inv[OF v v0] by auto
    fix i assume i: "i < n"
    hence i2: "i < length ws" using ws by auto
    let ?wsi = "ws ! i"
    have z: "0 < dim_col A'" using A' n by auto
    hence z2: "0 < length ws" using A' ws by auto
    have wsi[simp]: "ws!i : carrier_vec n" using ws i by auto
    hence ws0[simp]: "ws!0 = v" using hd_conv_nth[symmetric] hdws z2 by auto
    (* calculation chain: rewrite entry (i, 0) of A' step by step *)
    have "col A' 0 $ i = A' $$ (i, 0)" using A' i by auto
    also have "... = (W' * (A * W)) $$ (i, 0)" unfolding A'_def using W' A W by auto
    also have "... = row W' i  col (A * W) 0"
      apply (subst index_mult_mat) using W W' A i by auto
    also have "row W' i = vec_inv ?wsi"
      unfolding W'_def W_def unfolding corthogonal_inv_def using i ws by auto
    also have "col (A * W) 0 = A *v col W 0" using A W z A' by auto
    also have "col W 0 = v" unfolding W_def using z2 ws0 n col_mat_of_cols v by blast
    also have "A *v v = e v v" using eigen.
    also have "vec_inv ?wsi  (e v v) = e * (vec_inv ?wsi  v)"
      using scalar_prod_smult_distrib[OF vec_inv_closed[OF wsi] v].
    also have "... = ?f i"
    proof(cases "i = 0")
      case True thus ?thesis using vec_inv[OF v v0] by simp
    next 
      case False
      hence z: "0 < length ws" using i ws by auto
      note cwsi = carrier_vec_conjugate[OF wsi]
      have "vec_inv ?wsi  v = 1 / (?wsi ∙c ?wsi) * (conjugate ?wsi  v)"
        unfolding vec_inv_def unfolding smult_scalar_prod_distrib[OF cwsi v].. 
      also have "conjugate ?wsi  v = v ∙c ?wsi"
        using comm_scalar_prod[OF cwsi v].
      (* ws ! 0 = v and ws ! i are distinct members of the corthogonal list ws *)
      also have "... = 0"
        using corthogonalD[OF ws(2) z i2] False unfolding ws0 by auto
      finally show ?thesis using False by auto
    qed
    also have "... = vec n ?f $ i" using i by simp
    finally show "col A' 0 $ i = vec n ?f $ i" .
  qed
qed


text "Schur decomposition"
(* schur_decomposition A es returns a triple of matrices; the theorem of the same
   name below states that, under its assumptions, the triple (B, P, Q) satisfies
   similar_mat_wit A B P Q with B upper triangular and diag_mat B = es.

   Recursion over the list es:
   - []      : returns A together with two identity matrices of size dim_row A.
   - e # es  : v  = find_eigenvector A e,
               ws = gram_schmidt n (basis_completion v),
               W  = mat_of_cols n ws,  W' = corthogonal_inv W,  A' = W' * A * W;
               A' is split by split_block A' 1 1 into (A1, A2, A0, A3), the
               function recurses on A3 with es giving (B, P, Q), and the three
               results are assembled with four_block_mat, padding P and Q with a
               1 x 1 identity block and zero blocks and multiplying by W resp. W'. *)
fun schur_decomposition :: "'a::conjugatable_field mat  'a list  'a mat × 'a mat × 'a mat" where 
  "schur_decomposition A [] = (A, 1m (dim_row A), 1m (dim_row A))"
| "schur_decomposition A (e # es) = (let
       n = dim_row A;
       n1 = n - 1;
       v = find_eigenvector A e;
       ws = gram_schmidt n (basis_completion v);
       W = mat_of_cols n ws;
       W' = corthogonal_inv W;
       A' = W' * A * W;
       (A1,A2,A0,A3) = split_block A' 1 1;
       (B,P,Q) = schur_decomposition A3 es;
       z_row = (0m 1 n1);
       z_col = (0m n1 1);
       one_1 = 1m 1
     in (four_block_mat A1 (A2 * P) A0 B, 
     W * four_block_mat one_1 z_row z_col P, 
     four_block_mat one_1 z_row z_col Q * W'))"


(* Correctness of the function schur_decomposition.
   Assumptions: A is in carrier_mat n n, char_poly A is the product of the linear
   factors [:- e, 1:] for e in es, and schur_decomposition A es = (B,P,Q).
   Conclusion: similar_mat_wit A B P Q, upper_triangular B and diag_mat B = es.

   Proof by induction on es (A, B, P, Q and n arbitrary).
   In the Cons case the local definitions v, b, ws, W, W', A' mirror the let
   bindings of the function; the main steps are:
     - W' inverts W on both sides, hence A' is similar to A;
     - column 0 of A' is (e, 0, ..., 0) by corthogonal_col_ev_0, which fixes the
       blocks A1 and A0 of A';
     - char_poly A3 is the product over es, so the induction hypothesis applies;
     - similarity, triangularity and the diagonal are lifted through four_block_mat. *)
theorem schur_decomposition:
  assumes A: "(A::'a::conjugatable_ordered_field mat)  carrier_mat n n"
      and c: "char_poly A = ( (e :: 'a)  es. [:- e, 1:])"
      and B: "schur_decomposition A es = (B,P,Q)"
  shows "similar_mat_wit A B P Q  upper_triangular B  diag_mat B = es"
  using assms
proof (induct es arbitrary: n A B P Q)
  case Nil
  with degree_monic_char_poly[of A n]
  show ?case by (auto intro: similar_mat_wit_refl simp: diag_mat_def)
next
  case (Cons e es n A C P Q)
  let ?n1 = "n - 1"
  from Cons have A: "A  carrier_mat n n" and dim: "dim_row A = n" by auto
  let ?cp = "char_poly A"
  from Cons(3)
  have cp: "?cp = [: -e, 1 :] * (e  es. [:- e, 1:])" by auto
  have mon: "monic (e es. [:- e, 1:])" by (rule monic_prod_list, auto)
  have deg: "degree ?cp = Suc (degree (e es. [:- e, 1:]))" unfolding cp
    by (subst degree_mult_eq, insert mon, auto)
  with degree_monic_char_poly[OF A] have n: "n  0" by auto
  (* names for the intermediate values of one unfolding of the function *)
  define v where "v = find_eigenvector A e"
  define b where "b = basis_completion v"
  define ws where "ws = gram_schmidt n b"
  define W where "W = mat_of_cols n ws"
  define W' where "W' = corthogonal_inv W"
  define A' where "A' = W' * A * W"
  obtain A1 A2 A0 A3 where splitA': "split_block A' 1 1 = (A1,A2,A0,A3)"
    by (cases "split_block A' 1 1", auto)
  obtain B P' Q' where schur: "schur_decomposition A3 es = (B,P',Q')" 
    by (cases "schur_decomposition A3 es", auto)
  let ?P' = "four_block_mat (1m 1) (0m 1 ?n1) (0m ?n1 1) P'"
  let ?Q' = "four_block_mat (1m 1) (0m 1 ?n1) (0m ?n1 1) Q'"
  (* the result triple (C, P, Q) expressed through the names above *)
  have C: "C = four_block_mat A1 (A2 * P') A0 B" and P: "P = W * ?P'" and Q: "Q = ?Q' * W'"
    using Cons(4) unfolding schur_decomposition.simps
    Let_def list.sel dim
    v_def[symmetric] b_def[symmetric] ws_def[symmetric] W'_def[symmetric] W_def[symmetric]
    A'_def[symmetric] split splitA' schur by auto
  have e: "eigenvalue A e" 
    unfolding eigenvalue_root_char_poly[OF A] cp by simp
  from find_eigenvector[OF A e] have ev: "eigenvector A v e" unfolding v_def .
  from this[unfolded eigenvector_def]
  have v[simp]: "v  carrier_vec n" and v0: "v  0v n" using A by auto
  interpret cof_vec_space n "TYPE('a)" .
  (* properties of the completed basis b and of its Gram-Schmidt image ws *)
  from basis_completion[OF v v0, folded b_def]
  have span_b: "span (set b) = carrier_vec n" and dist_b: "distinct b" 
    and indep: "¬ lin_dep (set b)" and b: "set b  carrier_vec n" and hdb: "hd b = v" 
    and len_b: "length b = n" by auto
  from hdb len_b n obtain vs where bv: "b = v # vs" by (cases b, auto)
  from gram_schmidt_result[OF b dist_b indep refl, folded ws_def]
  have ws: "set ws  carrier_vec n" "corthogonal ws" "length ws = n" 
    by (auto simp: len_b)
  from gram_schmidt_hd[OF v, of vs, folded bv] have hdws: "hd ws = v" unfolding ws_def .
  have orth_W: "corthogonal_mat W" using orthogonal_mat_of_cols ws unfolding W_def.
  have W: "W  carrier_mat n n"
    using ws unfolding W_def using mat_of_cols_carrier(1)[of n ws] by auto
  have W': "W'  carrier_mat n n" unfolding W'_def corthogonal_inv_def using W 
    by (auto simp: mat_of_rows_def)  
  (* W' is a two-sided inverse of W, so A' and A are similar *)
  from corthogonal_inv_result[OF orth_W] 
  have W'W: "inverts_mat W' W" unfolding W'_def .
  hence WW': "inverts_mat W W'" using mat_mult_left_right_inverse[OF W' W] W' W
    unfolding inverts_mat_def by auto
  have A': "A'  carrier_mat n n" using W W' A unfolding A'_def by auto
  have A'A_wit: "similar_mat_wit A' A W' W"
    by (rule similar_mat_witI[of _ _ n], insert W W' A A' W'W WW', auto simp: A'_def
    inverts_mat_def)
  hence A'A: "similar_mat A' A" unfolding similar_mat_def by blast
  from similar_mat_wit_sym[OF A'A_wit] have simAA': "similar_mat_wit A A' W W'" by auto
  have eigen[simp]: "A *v v = e v v" and v0: "v  0v n"
    using v_def find_eigenvector[OF A e] A
    unfolding eigenvector_def by auto
  (* shape of the first column of A' and the resulting block structure *)
  let ?f = "(λ i. if i = 0 then e else 0)"
  have col0: "col A' 0 = vec n ?f"
    unfolding A'_def W'_def W_def
    using corthogonal_col_ev_0[OF A v v0 eigen n hdws ws].
  from A' n have "dim_row A' = 1 + ?n1" "dim_col A' = 1 + ?n1" by auto
  from split_block[OF splitA' this] have A2: "A2  carrier_mat 1 ?n1"
    and A3: "A3  carrier_mat ?n1 ?n1" 
    and A'block: "A' = four_block_mat A1 A2 A0 A3" by auto
  have A1id: "A1 = mat 1 1 (λ _. e)"
    using splitA'[unfolded split_block_def Let_def] arg_cong[OF col0, of "λ v. v $ 0"] A' n
    by (auto simp: col_def)
  have A1: "A1  carrier_mat 1 1" unfolding A1id by auto
  {
    fix i
    assume "i < ?n1"
    with arg_cong[OF col0, of "λ v. v $ Suc i"] A'
    have "A' $$ (Suc i, 0) = 0" by auto
  } note A'0 = this
  have A0id: "A0 = 0m ?n1 1"
    using splitA'[unfolded split_block_def Let_def] A'0 A' by auto
  have A0: "A0  carrier_mat ?n1 1" unfolding A0id by auto
  (* characteristic polynomial of A3, needed for the induction hypothesis *)
  from cp char_poly_similar[OF A'A]
  have cp: "char_poly A' = [: -e,1 :] * ( e  es. [:- e, 1:])" by simp
  also have "char_poly A' = char_poly A1 * char_poly A3" 
    unfolding A'block A0id
    by (rule char_poly_four_block_zeros_col[OF A1 A2 A3])
  also have "char_poly A1 = [: -e,1 :]"
    by (simp add: A1id char_poly_defs det_def signof_def sign_def)
  finally have cp: "char_poly A3 = ( e  es. [:- e, 1:])"
    by (metis mult_cancel_left pCons_eq_0_iff zero_neq_one)
  from Cons(1)[OF A3 cp schur]
  have simIH: "similar_mat_wit A3 B P' Q'" and ut: "upper_triangular B" and diag: "diag_mat B = es" by auto
  from similar_mat_witD2[OF A3 simIH] 
  have B: "B  carrier_mat ?n1 ?n1" and P': "P'  carrier_mat ?n1 ?n1" and Q': "Q'  carrier_mat ?n1 ?n1" 
    and PQ': "P' * Q' = 1m ?n1" by auto
  have A0_eq: "A0 = P' * A0 * 1m 1" unfolding A0id using P' by auto
  (* lift similarity, triangularity and the diagonal to the block matrix C *)
  have simA'C: "similar_mat_wit A' C ?P' ?Q'" unfolding A'block C
    by (rule similar_mat_wit_four_block[OF similar_mat_wit_refl[OF A1] simIH _ A0_eq A1 A3 A0],
    insert PQ' A2 P' Q', auto)
  have ut1: "upper_triangular A1" unfolding A1id by auto
  have ut: "upper_triangular C" unfolding C A0id
    by (intro upper_triangular_four_block[OF _ B ut1 ut], auto simp: A1id)
  from A1id have diagA1: "diag_mat A1 = [e]" unfolding diag_mat_def by auto
  from diag_four_block_mat[OF A1 B] have diag: "diag_mat C = e # es" unfolding diag diagA1 C by simp
  from ut similar_mat_wit_trans[OF simAA' simA'C, folded P Q] diag
  show ?case by blast
qed

(* schur_upper_triangular A es is the first component of the triple returned by
   schur_decomposition A es; the two other components are discarded. *)
definition schur_upper_triangular :: "'a::conjugatable_field mat  'a list  'a mat" where 
  "schur_upper_triangular A es = (case schur_decomposition A es of (B,_,_)  B)"


(* Properties of B = schur_upper_triangular A es for an n x n matrix A whose
   characteristic polynomial is the product of the linear factors [:- a, 1:] for
   a in es: B is in carrier_mat n n, B is upper triangular, and A is similar to B.
   All three facts are obtained from theorem schur_decomposition.
   Note that the fact name B is re-bound twice inside the proof (first to "B = C",
   then to "upper_triangular B"). *)
lemma schur_upper_triangular:
  assumes A: "(A :: 'a :: conjugatable_ordered_field mat)  carrier_mat n n"
  and linear: "char_poly A = ( a  es. [:- a, 1:])"
  defines B: "B  schur_upper_triangular A es"
  shows "B  carrier_mat n n" "upper_triangular B" "similar_mat A B" 
proof -
  let ?B = "schur_upper_triangular A es"
  obtain C P Q where schur: "schur_decomposition A es = (C,P,Q)" 
    by (cases "schur_decomposition A es", auto)
  hence B: "B = C" using A unfolding schur_upper_triangular_def B by auto
  from schur_decomposition[OF A linear schur]
  have sim: "similar_mat_wit A B P Q" and B: "upper_triangular B" unfolding B by auto
  from sim show "similar_mat A B" unfolding similar_mat_def by auto
  from similar_mat_witD2[OF A sim] show "B  carrier_mat n n" by auto
  show "upper_triangular B" by fact
qed

(* Existence form of the Schur decomposition: for an n x n matrix A whose
   characteristic polynomial is a product of linear factors [:- a, 1:] over a list
   es, the conclusion speaks about a matrix B in carrier_mat n n that is upper
   triangular and similar to A (presumably existentially quantified).
   It follows directly from lemma schur_upper_triangular. *)
lemma schur_decomposition_exists: assumes A: "A  carrier_mat n n"
  and linear: "char_poly A = ( (a :: 'a :: conjugatable_ordered_field)  es. [:- a, 1:])"
  shows " B  carrier_mat n n. upper_triangular B  similar_mat A B" 
  using schur_upper_triangular[OF A linear] by blast

(* Characteristic polynomial of a block matrix with a zero block in the third
   position of four_block_mat: if A = four_block_mat B C (0m m n) D with B of
   size n x n, C of size n x m and D of size m x m, and the characteristic
   polynomials of B and D are products of linear factors, then
   char_poly A = char_poly B * char_poly D.

   Proof idea: replace B and D by the upper-triangular matrices B' and D' from
   schur_decomposition, obtain a similar block matrix, and compute its
   characteristic polynomial from the diagonal.
   NOTE(review): the abbreviation ?prod is introduced but not referenced again in
   this proof, and the fact chained in by "from schur_decomposition_exists" does
   not appear to be needed by the following "obtain". *)
lemma char_poly_0_block: fixes A :: "'a :: conjugatable_ordered_field mat"
  assumes A: "A = four_block_mat B C (0m m n) D"
  and linearB: " es. char_poly B = ( a  es. [:- a, 1:])"
  and linearD: " es. char_poly D = ( a  es. [:- a, 1:])"
  and B: "B  carrier_mat n n"
  and C: "C  carrier_mat n m"
  and D: "D  carrier_mat m m"
  shows "char_poly A = char_poly B * char_poly D"
proof -
  from linearB obtain bs where cB: "char_poly B = (abs. [:- a, 1:])" by auto
  from linearD obtain ds where cD: "char_poly D = (ads. [:- a, 1:])" by auto
  from schur_decomposition_exists[OF B cB] 
  obtain B' PB QB where sB: "schur_decomposition B bs = (B',PB,QB)" 
    by (cases "schur_decomposition B bs", auto)
  obtain D' PD QD where sD: "schur_decomposition D ds = (D',PD,QD)" 
    by (cases "schur_decomposition D ds", auto)
  (* upper-triangular representatives B' and D' with known diagonals *)
  from schur_decomposition[OF B cB sB] similar_mat_witD2[OF B, of B'] have 
    simB: "similar_mat B B'" and utB: "upper_triangular B'" and diagB: "diag_mat B' = bs"
    and B': "B'  carrier_mat n n"
    by (auto simp: similar_mat_def)
  from schur_decomposition[OF D cD sD] similar_mat_witD2[OF D, of D'] have 
    simD: "similar_mat D D'" and utD: "upper_triangular D'" and diagD: "diag_mat D' = ds"
    and D': "D'  carrier_mat m m"
    by (auto simp: similar_mat_def)
  let ?z = "0m m n"
  from similar_mat_four_block_0_ex[OF simB simD C B D, folded A]
    obtain B0 where B0: "B0  carrier_mat n m" and sim: "similar_mat A (four_block_mat B' B0 ?z D')" 
    by auto
  let ?block = "four_block_mat B' B0 ?z D'"
  let ?cp = char_poly
  let ?prod = "QB * C * PD"
  let ?diag = "λ A. (adiag_mat A. [:- a, 1:])"
  (* calculation: char_poly A = ... = char_poly B * char_poly D *)
  from char_poly_similar[OF sim] have "?cp A = ?cp ?block" by simp
  also have " = ?diag ?block"
    by (rule char_poly_upper_triangular[OF four_block_carrier_mat[OF B' D'] upper_triangular_four_block[OF B' D' utB utD]])      
  also have " = ?diag B' * ?diag D'" unfolding diag_four_block_mat[OF B' D']
    by auto
  also have "?diag B' = ?cp B'"
    by (subst char_poly_upper_triangular[OF B' utB], simp)
  also have " = ?cp B"
    by (rule char_poly_similar[OF similar_mat_sym[OF simB]])
  also have "?diag D' = ?cp D'"
    by (subst char_poly_upper_triangular[OF D' utD], simp)
  also have " = ?cp D"
    by (rule char_poly_similar[OF similar_mat_sym[OF simD]])
  finally show ?thesis .
qed


(* Variant of char_poly_0_block where the zero block sits in the second position
   of four_block_mat: A = four_block_mat B (0m n m) C D with C of size m x n.
   The proof transposes the block matrix (transpose_four_block_mat), uses that the
   characteristic polynomial is unchanged under transposition
   (char_poly_transpose_mat), and then applies char_poly_0_block to the
   transposed blocks. *)
lemma char_poly_0_block': fixes A :: "'a :: conjugatable_ordered_field mat"
  assumes A: "A = four_block_mat B (0m n m) C D"
  and linearB: " es. char_poly B = ( a  es. [:- a, 1:])"
  and linearD: " es. char_poly D = ( a  es. [:- a, 1:])"
  and B: "B  carrier_mat n n"
  and C: "C  carrier_mat m n"
  and D: "D  carrier_mat m m"
  shows "char_poly A = char_poly B * char_poly D"
proof -
  let ?A = "four_block_mat B (0m n m) C D"
  let ?B = "transpose_mat B"
  let ?D = "transpose_mat D"
  have AC: "?A  carrier_mat (n + m) (n + m)" using B D by auto
  from arg_cong[OF transpose_four_block_mat[OF B zero_carrier_mat C D], of char_poly,
    unfolded char_poly_transpose_mat[OF AC], folded A]
  have "char_poly A =
    char_poly (four_block_mat ?B (transpose_mat C) (0m m n) ?D)" by auto
  also have " = char_poly ?B * char_poly ?D"
    by (rule char_poly_0_block[OF refl], insert B C D linearB linearD, auto)
  also have " = char_poly B * char_poly D" using B D
    by simp
  finally show ?thesis .
qed

end

Theory Jordan_Normal_Form_Existence

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Computing Jordan Normal Forms›

theory Jordan_Normal_Form_Existence
imports 
  Jordan_Normal_Form
  Column_Operations
  Schur_Decomposition
begin

(* Hide the short name of the constant Coset.order in this theory; with the
   (open) option the fully qualified name remains accessible.
   NOTE(review): presumably done to avoid a clash with another constant named
   "order" used later in this theory - confirm. *)
hide_const (open) Coset.order

text‹We prove existence of Jordan normal forms by means of first applying Schur's algorithm
 to convert a matrix into upper-triangular form, and then applying the following algorithm 
 to convert an upper-triangular matrix into a Jordan normal form. It only consists of 
 basic row- and column-operations.›

subsection ‹Pseudo Code Algorithm›

text ‹The following algorithm is used to compute JNFs from upper-triangular matrices.
It was generalized from \cite[Sect.~11.1.4]{PO07} where this algorithm was not
explicitly specified but only applied on an example. We further introduced step 2
which does not occur in the textbook description.

\begin{enumerate} 
\item[1.] Eliminate entries within blocks beside EV $a$ and above EV $b$ for $a \neq b$:
    for $A_{ij}$ with EV $a$ to the left of it and EV $b$ below it, perform
      @{term "add_col_sub_row (Aij / (b - a)) i j"}.
    The outer iteration should proceed by increasing $j$ and the inner loop by decreasing $i$.
      
\item[2.] Move rows with the same EV together. This can only be done after step 1;
    otherwise the triangular property is lost.
    Say both rows $i$ and $j$ ($i < j$) contain EV $a$, but all rows between $i$ and $j$ have a different EV.
    Then perform
      @{term "swap_cols_rows (i+1) j"}, 
      @{term "swap_cols_rows (i+2) j"}, 
      \ldots
      @{term "swap_cols_rows (j-1) j"}.
    Afterwards row $j$ will be at row $i+1$, and rows $i+1,\ldots,j-1$ will be moved to $i+2,\ldots,j$.
    The global iteration works by increasing $j$.

\item[3.] Transform each EV-block into JNF, do this for increasing upper $n \times k$ matrices, 
    where each new column $k$ will be treated as follows.
\begin{enumerate}
 \item[a)] Eliminate entries $A_{ik}$ in rows of form $0 \ldots 0\ ev\ 1\ 0 \ldots 0\ A_{ik}$:
         @{term "add_col_sub_row (-Aik) (i+1) k"}.
       Perform elimination by increasing $i$.
\item[b)] Figure out largest JB (of $n-1 \times n-1$ sub-matrix) with lowest row of form $0 \ldots 0\ ev\ 0 \ldots 0\ A_{lk}$
       where $A_{lk} \neq 0$, and set $x := A_{lk}$.
\item[c)] If such a JB does not exist, continue with next column.
  Otherwise, eliminate all other non-zero-entries $y := A_{ik}$ via row $l$:
         @{term "add_col_sub_row (y/x) i l"},
         @{term "add_col_sub_row (y/x) (i-1) (l-1)"},
         @{term "add_col_sub_row (y/x) (i-2) (l-2)"}, \ldots
         where the number of steps is determined by the size of the JB left-above of $A_{ik}$.
         Perform an iteration over $i$.
\item[d)] Normalize value in row $l$ to 1:
         @{term "mult_col_div_row (1/x) k"}.
\item[e)] Move the 1 down from row $l$ to row $k-1$:
         @{term "swap_cols_rows (l+1) k"},
         @{term "swap_cols_rows (l+2) k"},
         \ldots,
         @{term "swap_cols_rows (k-1) k"}.
\end{enumerate}
\end{enumerate}
›


subsection ‹Real Algorithm›

(* lookup_ev ev i A searches the diagonal entries A $$ (i-1,i-1), ..., A $$ (0,0)
   from the highest index downwards and returns Some k for the first index k
   whose diagonal entry equals ev; it returns None if no index below i matches. *)
fun lookup_ev :: "'a  nat  'a mat  nat option" where
  "lookup_ev ev 0 A = None"
| "lookup_ev ev (Suc i) A = (if A $$ (i,i) = ev then Some i else lookup_ev ev i A)"

(* swap_cols_rows_block i j A applies swap_cols_rows i j, then
   swap_cols_rows (i+1) j, and so on, as long as the first index is below j;
   for i not below j the matrix is returned unchanged.
   Termination is shown with the measure j - i. *)
function swap_cols_rows_block :: "nat  nat  'a mat  'a mat" where
  "swap_cols_rows_block i j A = (if i < j then
    swap_cols_rows_block (Suc i) j (swap_cols_rows i j A) else A)"
  by pat_completeness auto
termination by (relation "measure (λ (i,j,A). j - i)") auto

(* identify_block A i walks from index i downwards as long as the entry
   A $$ (i', i'+1) directly above the diagonal equals 1.
   It returns 0 if the walk reaches index 0, and otherwise Suc i' for the first
   i' (counting down from i - 1) whose entry A $$ (i', Suc i') is not 1. *)
fun identify_block :: "'a :: one mat  nat  nat" where
  "identify_block A 0 = 0"
| "identify_block A (Suc i) = (if A $$ (i,Suc i) = 1 then
    identify_block A i else (Suc i))"

(* identify_blocks_main A i list accumulates index pairs in front of list.
   For i = Suc i_end it computes i_begin = identify_block A i_end, prepends the
   pair (i_begin, i_end) and continues with i_begin; for i = 0 it returns the
   accumulated list.
   Only pattern completeness is discharged here; no termination proof appears in
   the lines directly following this definition. *)
function identify_blocks_main :: "'a :: ring_1 mat  nat  (nat × nat) list  (nat × nat) list" where
  "identify_blocks_main A 0 list = list"
| "identify_blocks_main A (Suc i_end) list = (
    let i_begin = identify_block A i_end
    in identify_blocks_main A i_begin ((i_begin, i_end) # list)
    )"
  by pat_completeness auto

(* identify_blocks A i runs identify_blocks_main A i with an empty accumulator,
   yielding the list of (begin, end) index pairs found below index i. *)
definition identify_blocks :: "'a :: ring_1 mat  nat  (nat × nat)list" where
  "identify_blocks A i = identify_blocks_main A i []"

(* find_largest_block cur blocks folds over a list of (start, end) pairs, keeping
   one candidate pair. For each list element the difference i_end - i_start is
   compared with the difference m_end - m_start of the current candidate, and the
   element replaces the candidate when the comparison succeeds
   (presumably "at least as large", so that later blocks win ties - confirm).
   The remaining candidate is returned when the list is exhausted. *)
fun find_largest_block :: "nat × nat  (nat × nat)list  nat × nat" where
  "find_largest_block block [] = block"
| "find_largest_block (m_start,m_end) ((i_start,i_end) # blocks) = 
    (if i_end - i_start  m_end - m_start then
      find_largest_block (i_start,i_end) blocks else
      find_largest_block (m_start,m_end) blocks)"

(* lookup_other_ev ev i A has the same search order as lookup_ev (diagonal entries
   from index i - 1 down to 0) but uses a different test on A $$ (i,i) and ev
   (presumably "differs from ev"). It returns Some k for the first index that
   passes the test and None if there is none. *)
fun lookup_other_ev :: "'a  nat  'a mat  nat option" where
  "lookup_other_ev ev 0 A = None"
| "lookup_other_ev ev (Suc i) A = (if A $$ (i,i)  ev then Some i else lookup_other_ev ev i A)"

(* partition_ev_blocks A bs splits A into a list of diagonal blocks, prepending
   them to the accumulator bs. With n = dim_row A:
   - if n = 0, the accumulator is returned;
   - otherwise lookup_other_ev is asked for an index i below n - 1, using the last
     diagonal entry A $$ (n-1, n-1) as reference value;
       None   : A itself is prepended to bs;
       Some i : A is split by split_block A (Suc i) (Suc i); the lower-right part
                LR is prepended to bs and the function continues with the
                upper-left part UL.
   Defined as a tail-recursive partial_function, so no termination proof is
   required; the defining equation is registered as a code equation. *)
partial_function (tailrec) partition_ev_blocks :: "'a mat  'a mat list  'a mat list" where
  [code]: "partition_ev_blocks A bs = (let n = dim_row A in
    if n = 0 then bs
    else (case lookup_other_ev (A $$ (n-1, n-1)) (n-1) A of 
      None  A # bs 
    | Some i  case split_block A (Suc i) (Suc i) of (UL,_,_,LR)  partition_ev_blocks UL (LR # bs)))"

context 
  fixes n :: nat
  and ty :: "'a :: field itself"
begin

(* Step 1 of the algorithm. step_1_main i j A (n is fixed by the context):
   - stops and returns A once the test on j and n succeeds (presumably j >= n,
     consistent with the termination measure n - j);
   - for i = 0 it moves on to the next column, restarting with i = j + 1;
   - otherwise it looks at row i' = i - 1: with ev_left = A $$ (i',i'),
     ev_below = A $$ (j,j) and aij = A $$ (i',j), it applies
     add_col_sub_row (aij / (ev_below - ev_left)) i' j when the guard on
     ev_left, ev_below and aij holds (presumably: the two eigenvalues differ and
     aij is non-zero), and then continues with i' in the same column.
   Termination: lexicographic measure (n - j, i). *)
function step_1_main :: "nat  nat  'a mat  'a mat" where
  "step_1_main i j A = (if j  n then A else if i = 0 then step_1_main (j+1) (j+1) A
    else let 
      i' = i - 1;
      ev_left = A $$ (i',i');
      ev_below = A $$ (j,j);
      aij = A $$ (i',j);
      B = if (ev_left  ev_below  aij  0) then add_col_sub_row (aij / (ev_below - ev_left)) i' j A else A
     in step_1_main i' j B)"
  by pat_completeness auto
termination by (relation "measures [λ (i,j,A). n - j, λ (i,j,A). i]") auto

(* Step 2 of the algorithm. step_2_main j A (n is fixed by the context):
   - returns A once the test on j and n succeeds (presumably j >= n, consistent
     with the termination measure n - j);
   - otherwise takes ev = A $$ (j,j) and searches with lookup_ev for an index
     i below j carrying the same diagonal value; if one is found, row/column j is
     moved next to it via swap_cols_rows_block (Suc i) j, otherwise A is kept;
     the function then continues with Suc j. *)
function step_2_main :: "nat  'a mat  'a mat" where
  "step_2_main j A = (if j  n then A 
    else 
      let ev = A $$ (j,j);
        B = (case lookup_ev ev j A of 
          None  A
        | Some i  swap_cols_rows_block (Suc i) j A
          )
     in step_2_main (Suc j) B)"
  by pat_completeness auto
termination by (relation "measure (λ (j,A). n - j)") auto

(* Step 3a of the algorithm. step_3_a i j A processes the rows i - 1, ..., 0 of
   column j in decreasing order. For a row i with aij = A $$ (i,j) it applies
   add_col_sub_row (- aij) (Suc i) j when the guard holds; the guard combines
   "A $$ (i,i+1) = 1" with a test on aij against 0 (presumably: aij is non-zero).
   Each test reads the matrix produced by the previous update. *)
fun step_3_a :: "nat  nat  'a mat  'a mat" where 
  "step_3_a 0 j A = A"
| "step_3_a (Suc i) j A = (let 
    aij = A $$ (i,j);
    B = (if A $$ (i,i+1) = 1  aij  0 
    then add_col_sub_row (- aij) (Suc i) j A else A)
    in step_3_a i j B)"

(* Inner loop of step 3c. step_3_c_inner_loop val l i k A performs k updates:
   it applies add_col_sub_row val i l, then decrements both l and i by one and
   repeats with the remaining count; with count 0 the matrix is returned. *)
fun step_3_c_inner_loop :: "'a  nat  nat  nat  'a mat  'a mat" where
  "step_3_c_inner_loop val l i 0 A = A"
| "step_3_c_inner_loop val l i (Suc k) A = step_3_c_inner_loop val (l - 1) (i - 1) k 
     (add_col_sub_row val i l A)"

(* Step 3c of the algorithm. step_3_c x l k blocks A iterates over the list of
   (i_begin, i_end) pairs. The block whose end index equals l is skipped; for
   every other block step_3_c_inner_loop is run with value A $$ (i_end,k) / x,
   starting at indices l and i_end, for Suc i_end - i_begin iterations.
   The entry A $$ (i_end,k) is read from the matrix as updated by the blocks
   processed before. *)
fun step_3_c :: "'a  nat  nat  (nat × nat)list  'a mat  'a mat" where
  "step_3_c x l k [] A = A"
| "step_3_c x l k ((i_begin,i_end) # blocks) A = (
    let 
      B = (if i_end = l then A else 
        step_3_c_inner_loop (A $$ (i_end,k) / x) l i_end (Suc i_end - i_begin) A)
      in step_3_c x l k blocks B)"

(* Step 3 of the algorithm. step_3_main k A (n is fixed by the context) handles
   column k and then continues with column Suc k; it returns A once the test on
   k and n succeeds (presumably k >= n, consistent with the termination measure
   n - k). For one column:
     B          = result of step_3_a (k-1) k A;
     all_blocks = identify_blocks B k;
     blocks     = those blocks whose entry B $$ (snd block, k) passes the filter
                  test against 0 (presumably: is non-zero);
   if blocks is empty, B is kept; otherwise a block (l_start, l) is chosen by
   find_largest_block, x = B $$ (l,k), and the sub-steps 3c (step_3_c),
   3d (mult_col_div_row (inverse x) k) and 3e (swap_cols_rows_block (Suc l) k)
   are applied in this order. *)
function step_3_main :: "nat  'a mat  'a mat" where
  "step_3_main k A = (if k  n then A 
    else let 
      B = step_3_a (k-1) k A; ― ‹3_a›
      all_blocks = identify_blocks B k;
      blocks = filter (λ block. B $$ (snd block,k)  0) all_blocks;
      F = (if blocks = [] ― ‹column k› has only 0s›
        then B
        else let          
          (l_start,l) = find_largest_block (hd blocks) (tl blocks); ― ‹3_b›
          x = B $$ (l,k); 
          C = step_3_c x l k blocks B; ― ‹3_c›
          D = mult_col_div_row (inverse x) k C; ― ‹3_d›
          E = swap_cols_rows_block (Suc l) k D ― ‹3_e›
        in E)
      in step_3_main (Suc k) F)"
  by pat_completeness auto
termination by (relation "measure (λ (k,A). n - k)") auto

end

(* step_1 A runs step_1_main with n = dim_row A, starting at i = 0 and j = 0. *)
definition step_1 :: "'a :: field mat  'a mat" where
  "step_1 A = step_1_main (dim_row A) 0 0 A"

(* step_2 A runs step_2_main with n = dim_row A, starting at j = 0. *)
definition step_2 :: "'a :: field mat  'a mat" where
  "step_2 A = step_2_main (dim_row A) 0 A"

(* step_3 A runs step_3_main with n = dim_row A, starting at column k = 1. *)
definition step_3 :: "'a :: field mat  'a mat" where
  "step_3 A = step_3_main (dim_row A) 1 A"

(* Remove the defining equations of the recursive functions from the simpset.
   The dimension lemmas below unfold them explicitly for specific arguments
   (e.g. swap_cols_rows_block.simps[of i j A]) instead of relying on simp. *)
declare swap_cols_rows_block.simps[simp del]
declare step_1_main.simps[simp del]
declare step_2_main.simps[simp del]
declare step_3_main.simps[simp del]

(* Reads off a list of (size, diagonal entry) pairs from A, working backwards from the given index.
   For Suc i_end it computes i_start = identify_block A i_end, recurses on i_start and appends the
   pair (Suc i_end - i_start, A $$ (i_start,i_start)). Index 0 yields the empty list.
   Only pattern completeness is discharged here; termination is proved by a separate
   termination command. *)
function jnf_vector_main :: "nat  'a :: one mat  (nat × 'a) list" where
  "jnf_vector_main 0 A = []"
| "jnf_vector_main (Suc i_end) A = (let 
    i_start = identify_block A i_end 
    in jnf_vector_main i_start A @ [(Suc i_end - i_start, A $$ (i_start,i_start))])"
  by pat_completeness auto

(* jnf_vector_main started at index dim_row A. *)
definition jnf_vector :: "'a :: one mat  (nat × 'a) list" where
  "jnf_vector A = jnf_vector_main (dim_row A) A"

(* Overall pipeline: apply step_1 and then step_2 to A, split the result with
   partition_ev_blocks (accumulator []), run step_3 followed by jnf_vector on every part and
   concatenate the resulting lists. *)
definition triangular_to_jnf_vector :: "'a :: field mat  (nat × 'a) list" where
  "triangular_to_jnf_vector A  let B = step_2 (step_1 A)
    in concat (map (jnf_vector o step_3) (partition_ev_blocks B []))"

subsection ‹Preservation of Dimensions›

(* swap_cols_rows_block keeps both dim_row and dim_col. Both facts are stated in one lemma so that a
   single induction over swap_cols_rows_block.induct suffices; the equation is unfolded by hand
   because it is not a simp rule. *)
lemma swap_cols_rows_block_dims_main: 
  "dim_row (swap_cols_rows_block i j A) = dim_row A  dim_col (swap_cols_rows_block i j A) = dim_col A"
proof (induct i j A rule: swap_cols_rows_block.induct)
  case (1 i j A)
  thus ?case unfolding swap_cols_rows_block.simps[of i j A]
    by (auto split: if_splits)
qed

(* Simp-rule versions of swap_cols_rows_block_dims_main, plus the corresponding carrier_mat n n
   statement (obtained by unfolding carrier_mat_def). *)
lemma swap_cols_rows_block_dims[simp]: 
  "dim_row (swap_cols_rows_block i j A) = dim_row A"
  "dim_col (swap_cols_rows_block i j A) = dim_col A"
  "A  carrier_mat n n  swap_cols_rows_block i j A  carrier_mat n n"
  using swap_cols_rows_block_dims_main unfolding carrier_mat_def by auto

(* step_1_main keeps both dim_row and dim_col; proved by step_1_main.induct (with n fixed via
   "taking") after unfolding the defining equation for the current arguments. *)
lemma step_1_main_dims_main: 
  "dim_row (step_1_main n i j A) = dim_row A  dim_col (step_1_main n i j A) = dim_col A"
proof (induct i j A taking: n rule: step_1_main.induct)
  case (1 i j A)
  thus ?case unfolding step_1_main.simps[of n i j A]
    by (auto split: if_splits simp: Let_def)
qed

(* The two parts of step_1_main_dims_main as separate simp rules. *)
lemma step_1_main_dims[simp]: 
  "dim_row (step_1_main n i j A) = dim_row A"
  "dim_col (step_1_main n i j A) = dim_col A"
  using step_1_main_dims_main by blast+

(* step_2_main keeps both dim_row and dim_col; proved by step_2_main.induct, splitting on the
   if and option case distinctions that appear after unfolding. *)
lemma step_2_main_dims_main: 
  "dim_row (step_2_main n j A) = dim_row A  dim_col (step_2_main n j A) = dim_col A"
proof (induct j A taking: n rule: step_2_main.induct)
  case (1 j A)
  thus ?case unfolding step_2_main.simps[of n j A]
    by (auto split: if_splits option.splits simp: Let_def)
qed

(* The two parts of step_2_main_dims_main as separate simp rules. *)
lemma step_2_main_dims[simp]: 
  "dim_row (step_2_main n j A) = dim_row A"
  "dim_col (step_2_main n j A) = dim_col A"
  using step_2_main_dims_main by blast+

(* step_3_a keeps both dim_row and dim_col; direct induction over step_3_a.induct. *)
lemma step_3_a_dims_main: 
  "dim_row (step_3_a i j A) = dim_row A  dim_col (step_3_a i j A) = dim_col A"
  by (induct i j A rule: step_3_a.induct, auto split: if_splits simp: Let_def)

(* The two parts of step_3_a_dims_main as separate simp rules. *)
lemma step_3_a_dims[simp]: 
  "dim_row (step_3_a i j A) = dim_row A"
  "dim_col (step_3_a i j A) = dim_col A"
  using step_3_a_dims_main by blast+

(* step_3_c_inner_loop keeps both dim_row and dim_col; direct induction over its induction rule. *)
lemma step_3_c_inner_loop_dims_main: 
  "dim_row (step_3_c_inner_loop val l i j A) = dim_row A  dim_col (step_3_c_inner_loop val l i j A) = dim_col A"
  by (induct val l i j A rule: step_3_c_inner_loop.induct, auto)

(* The two parts of step_3_c_inner_loop_dims_main as separate simp rules. *)
lemma step_3_c_inner_loop_dims[simp]: 
  "dim_row (step_3_c_inner_loop val l i j A) = dim_row A"
  "dim_col (step_3_c_inner_loop val l i j A) = dim_col A"
  using step_3_c_inner_loop_dims_main by blast+

(* step_3_c keeps both dim_row and dim_col; direct induction over step_3_c.induct.
   Note that the fourth argument, named i here, is the list of blocks. *)
lemma step_3_c_dims_main: 
  "dim_row (step_3_c x l k i A) = dim_row A  dim_col (step_3_c x l k i A) = dim_col A"
  by (induct x l k i A rule: step_3_c.induct, auto simp: Let_def)

(* The two parts of step_3_c_dims_main as separate simp rules. *)
lemma step_3_c_dims[simp]: 
  "dim_row (step_3_c x l k i A) = dim_row A"
  "dim_col (step_3_c x l k i A) = dim_col A"
  using step_3_c_dims_main by blast+
  
(* step_3_main keeps both dim_row and dim_col; proved by step_3_main.induct, splitting on the
   if, pair and option case distinctions that appear after unfolding. *)
lemma step_3_main_dims_main: 
  "dim_row (step_3_main n k A) = dim_row A  dim_col (step_3_main n k A) = dim_col A"
proof (induct k A taking: n rule: step_3_main.induct)
  case (1 k A)
  thus ?case unfolding step_3_main.simps[of n k A]
    by (auto split: if_splits prod.splits option.splits simp: Let_def)
qed

(* The two parts of step_3_main_dims_main as separate simp rules. *)
lemma step_3_main_dims[simp]: 
  "dim_row (step_3_main n j A) = dim_row A"
  "dim_col (step_3_main n j A) = dim_col A"
  using step_3_main_dims_main by blast+

(* step_1, step_2 and step_3 keep dim_row and dim_col; follows from the simp rules for the
   corresponding main loops after unfolding the three definitions. *)
lemma triangular_to_jnf_steps_dims[simp]: 
  "dim_row (step_1 A) = dim_row A"
  "dim_col (step_1 A) = dim_col A"
  "dim_row (step_2 A) = dim_row A"
  "dim_col (step_2 A) = dim_col A"
  "dim_row (step_3 A) = dim_row A"
  "dim_col (step_3 A) = dim_col A"
  unfolding step_1_def step_2_def step_3_def o_def by auto

subsection ‹Properties of Auxiliary Algorithms›

(* Characterises a successful lookup lookup_ev ev j A = Some i: i < j, the diagonal entry at i is ev,
   and a statement about the diagonal entries A $$ (k,k) for k strictly between i and j.
   NOTE(review): the relation in that last part is not visible in this copy; presumably those
   entries differ from ev - confirm upstream. Proof by induction on j. *)
lemma lookup_ev_Some: 
  "lookup_ev ev j A = Some i  
  i < j  A $$ (i,i) = ev  ( k. i < k  k < j  A $$ (k,k)  ev)"
  by (induct j, auto split: if_splits, case_tac "k = j", auto)

(* Characterises a failed lookup lookup_ev ev j A = None by a statement about A $$ (i,i) and ev for
   every i < j. NOTE(review): the relation symbol is not visible in this copy; presumably the
   entries differ from ev - confirm upstream. Proof by induction on j. *)
lemma lookup_ev_None: "lookup_ev ev j A = None  i < j  A $$ (i,i)  ev"
  by (induct j, auto split: if_splits, case_tac "i = j", auto)

(* Entry lookup in swap_cols_rows_block low high A: the entry at (i,j) is the entry of A at the
   index pair obtained by mapping i and j independently, where low is sent to high, an index
   above low within the range bounded by high is sent to its predecessor, and any other index is
   kept. All indices involved must be within both dimensions of A.
   NOTE(review): the comparison symbols relating low/high and the range bound are not visible in
   this copy - confirm upstream.
   Proof: induction over swap_cols_rows_block.induct. If low < high fails, low = high and the
   function returns A. Otherwise the function is one swap_cols_rows low high followed by the
   recursive call on Suc low, so the induction hypothesis is combined with swap_cols_rows_index. *)
lemma swap_cols_rows_block_index[simp]: 
  "i < dim_row A  i < dim_col A  j < dim_row A  j < dim_col A 
   low  high  high < dim_row A  high < dim_col A  
   swap_cols_rows_block low high A $$ (i,j) = A $$ 
    (if i = low then high else if i > low  i  high then i - 1 else i,
     if j = low then high else if j > low  j  high then j - 1 else j)"
proof (induct low high A rule: swap_cols_rows_block.induct)
  case (1 low high A)
  let ?r = "dim_row A" let ?c = "dim_col A"
  let ?A = "swap_cols_rows_block low high A"
  let ?B = "swap_cols_rows low high A"
  let ?C = "swap_cols_rows_block (Suc low) high ?B"
  note [simp] = swap_cols_rows_block.simps[of low high A]
  from 1(2-) have lh: "low  high" by simp
  show ?case
  proof (cases "low < high")
    case False
    with lh have lh: "low = high" by auto
    from False have "?A = A" by simp
    with lh show ?thesis by auto
  next
    case True
    hence A: "?A = ?C" by simp
    from True lh have "Suc low  high" by simp
    note IH = 1(1)[unfolded swap_cols_rows_carrier, OF True 1(2-5) this 1(7-)]
    note * = 1(2-)
    let ?i = "if i = Suc low then high else if Suc low < i  i  high then i - 1 else i"
    let ?j = "if j = Suc low then high else if Suc low < j  j  high then j - 1 else j"
    have cong: " i j i' j'. i = i'  j = j'  A $$ (i,j) = A $$ (i',j')" by auto
    have ij: "?i < ?r" "?i < ?c" "?j < ?r" "?j < ?c" "low < ?r" "high < ?r" using * True by auto
    show ?thesis unfolding A IH
      by (subst swap_cols_rows_index[OF ij], rule cong, insert * True, auto)
  qed
qed
  
(* The pair returned by find_largest_block block blocks is related to insert block (set blocks), and
   its extent m_e - m_b is compared with snd b - fst b for every b in that set.
   NOTE(review): the membership and order symbols are not visible in this copy; presumably the
   result is a member of the set and has maximal extent - confirm upstream.
   Proof: induction over find_largest_block.induct; in the cons case the comparison of
   i_end - i_start with m_end - m_start decides which of the two induction hypotheses applies. *)
lemma find_largest_block_main: assumes "find_largest_block block blocks = (m_b, m_e)"
  shows "(m_b, m_e)  insert block (set blocks)
   ( b  insert block (set blocks). m_e - m_b  snd b - fst b)"
  using assms
proof (induct block blocks rule: find_largest_block.induct)
  case (2 m_start m_end i_start i_end blocks)  
  let ?res = "find_largest_block (m_start,m_end) ((i_start,i_end) # blocks)"
  show ?case
  proof (cases "i_end - i_start  m_end - m_start")
    case True
    with 2(3-) have "find_largest_block (i_start,i_end) blocks = (m_b, m_e)" by auto
    note IH = 2(1)[OF True this]
    thus ?thesis using True by auto
  next
    case False
    with 2(3-) have "find_largest_block (m_start,m_end) blocks = (m_b, m_e)" by auto
    note IH = 2(2)[OF False this]
    thus ?thesis using False by auto
  qed
qed simp

(* Version of find_largest_block_main for the call pattern used in step_3_main, i.e. applied to
   hd blocks and tl blocks of a list with a side condition bl on blocks. The proof rewrites
   insert (hd blocks) (set (tl blocks)) to set blocks by a case split on the list.
   NOTE(review): the relation in bl is not visible in this copy; presumably blocks is
   non-empty - confirm upstream. *)
lemma find_largest_block: assumes bl: "blocks  []"
  and find: "find_largest_block (hd blocks) (tl blocks) = (m_begin, m_end)"
  shows "(m_begin,m_end)  set blocks"
  " i_begin i_end. (i_begin,i_end)  set blocks  m_end - m_begin  i_end - i_begin"
proof -
  from bl have id: "insert (hd blocks) (set (tl blocks)) = set blocks" by (cases blocks, auto)
  from find_largest_block_main[OF find, unfolded id] 
  show "(m_begin,m_end)  set blocks"
    " i_begin i_end. (i_begin,i_end)  set blocks  m_end - m_begin  i_end - i_begin" by auto
qed

context 
  fixes ev :: "'a :: one"
  and A :: "'a mat"
begin

(* Properties of i = identify_block A j, stated as one formula ?P j: an order relation between i and
   j, a statement that i = 0 or about the entry A $$ (i - 1, i) and 1, and A $$ (k, Suc k) = 1 for
   the indices k between i and j.
   NOTE(review): the order and (in)equality symbols are not visible in this copy - confirm upstream.
   Proof: induction on j with a case split on whether A $$ (j, Suc j) = 1; in the positive case the
   induction hypothesis covers k < j and the case assumption covers k = j. *)
lemma identify_block_main: assumes "identify_block A j = i"
  shows "i  j  (i = 0  A $$ (i - 1, i)  1)  ( k. i  k  k < j  A $$ (k, Suc k) = 1)"
    (is "?P j")
  using assms
proof (induct j)
  case (Suc j)
  show ?case
  proof (cases "A $$ (j, Suc j) = 1")
    case False
    with Suc(2) show ?thesis by auto
  next
    case True
    with Suc(2) have "identify_block A j = i" by simp
    note IH = Suc(1)[OF this] 
    {
      fix k
      assume "k  i" and "k < Suc j"      
      hence "A $$ (k, Suc k) = 1"
      proof (cases "k < j")
        case True
        with IH k  i show ?thesis by auto
      next
        case False
        with k < Suc j have "k = j" by auto
        with True show ?thesis by auto
      qed
    }
    with IH show ?thesis by auto
  qed
qed simp


(* Order relation between identify_block A i and i, taken from identify_block_main instantiated
   with refl. This is the fact used in the termination proofs of identify_blocks_main and
   jnf_vector_main. *)
lemma identify_block_le: "identify_block A i  i"
  using identify_block_main[OF refl] by blast
end


(* The conjuncts of identify_block_main split into three separately usable facts. The local
   abbreviation ?ev is not used in the proof. *)
lemma identify_block: assumes "identify_block A j = i"
  shows "i  j"
  "i = 0  A $$ (i - 1, i)  1"
  "i  k  k < j  A $$ (k, Suc k) = 1"
proof -
  let ?ev = "A $$ (j,j)"
  note main = identify_block_main[OF assms]
  from main show "i  j" by blast
  from main show "i = 0  A $$ (i - 1, i)  1" by blast
  assume "i  k"
  with main have main: "k < j  A $$ (k, Suc k) = 1" by blast
  thus "k < j  A $$ (k, Suc k) = 1" by blast
qed
    
(* First fact of identify_block under its own name: the order relation with the result named by
   an equation premise. *)
lemmas identify_block_le' = identify_block(1)

(* identify_block_le' with the defining equation oriented the other way round
   (j = identify_block A i). *)
lemma identify_block_le_rev: "j = identify_block A i  j  i"
  using identify_block_le'[of A i j] by auto
  
(* Termination of identify_blocks_main: the measure is the middle (index) component of the
   argument triple; identify_block_le together with le_imp_less_Suc discharges the goals. *)
termination identify_blocks_main by (relation "measure (λ (_,i,list). i)", 
  auto simp: identify_block_le le_imp_less_Suc)

(* Termination of jnf_vector_main: the measure is the index argument; identify_block_le together
   with le_imp_less_Suc discharges the goals. *)
termination jnf_vector_main by (relation "measure (λ (i,A). i)", 
  auto simp: identify_block_le le_imp_less_Suc)

(* Bounds for every pair (i_start,i_end) in the result of identify_blocks_main A i list, given that
   the pairs already in the accumulator list satisfy the same bounds with respect to k and that
   i is bounded by k.
   NOTE(review): the order symbols between i_start and i_end and between i and k are not visible
   in this copy - confirm upstream.
   Proof: induction over identify_blocks_main.induct; in the step case the newly added pair
   (i_b, i_e) satisfies the bounds by identify_block_le' and the assumption on Suc i_e. *)
lemma identify_blocks_main: assumes "(i_start,i_end)  set (identify_blocks_main A i list)" 
  and " i_s i_e. (i_s, i_e)  set list  i_s  i_e  i_e < k"
  and "i  k"
  shows "i_start  i_end  i_end < k" using assms
proof (induct A i list rule: identify_blocks_main.induct)
  case (2 A i_e list)
  obtain i_b where id: "identify_block A i_e = i_b" by force
  note IH = 2(1)[OF id[symmetric]]
  let ?res = "identify_blocks_main A (Suc i_e) list"  
  let ?rec = "identify_blocks_main A i_b ((i_b, i_e) # list)"
  have res: "?res = ?rec" using id by simp
  from 2(2)[unfolded res] have "(i_start, i_end)  set ?rec" .
  note IH = IH[OF this]
  from 2(3-) have iek: "i_e < k" by simp
  from identify_block_le'[OF id] have ibe: "i_b  i_e" .
  from ibe iek have "i_b  k" by simp
  note IH = IH[OF _ this]
  show ?thesis
    by (rule IH, insert ibe iek 2(3-), auto)
qed simp

(* Bounds for the pairs returned by identify_blocks B k, obtained from the lemma
   identify_blocks_main after unfolding identify_blocks_def. *)
lemma identify_blocks: assumes "(i_start,i_end)  set (identify_blocks B k)" 
  shows "i_start  i_end  i_end < k"
  using identify_blocks_main[OF assms[unfolded identify_blocks_def]] by auto

subsection ‹Proving Similarity›

context
begin
(* swap_cols_rows_block i j A is similar to A for a matrix in carrier_mat n n with j < n and an
   order condition on i and j (symbol not visible in this copy).
   Proof: induction over swap_cols_rows_block.induct. If i < j fails the function returns A
   (similar_mat_refl); otherwise one swap (swap_cols_rows_similar) is chained with the induction
   hypothesis by similar_mat_trans. *)
private lemma swap_cols_rows_block_similar: assumes "A  carrier_mat n n"
  and "j < n" and "i  j"
  shows "similar_mat (swap_cols_rows_block i j A) A"
  using assms
proof (induct i j A rule: swap_cols_rows_block.induct)
  case (1 i j A)
  hence A: "A  carrier_mat n n"
    and jn: "j < n" and ij: "i  j" by auto
  note [simp] = swap_cols_rows_block.simps[of i j]
  show ?case
  proof (cases "i < j")
    case False
    thus ?thesis using similar_mat_refl[OF A] by simp
  next
    case True note ij = this
    let ?B = "swap_cols_rows i j A"
    let ?C = "swap_cols_rows_block (Suc i) j ?B"
    from swap_cols_rows_similar[OF A _ jn, of i] ij jn
    have BA: "similar_mat ?B A" by auto
    have CB: "similar_mat ?C ?B"
      by (rule 1(1)[OF ij _ jn], insert ij A, auto)
    from similar_mat_trans[OF CB BA] show ?thesis using ij by simp
  qed
qed

(* step_1_main n i j A is similar to A for A in carrier_mat n n under an order condition on i and j
   (symbol not visible in this copy).
   Proof: induction over step_1_main.induct. When the guard on j and n holds the result is A.
   For i = 0 the induction hypothesis applies directly. Otherwise the single update B (an
   add_col_sub_row on row/column i - 1 and j, or A itself) is similar to A by
   add_col_sub_row_similar, and the recursive call on (i - 1, j, B) is handled by the induction
   hypothesis; both are chained with similar_mat_trans. *)
private lemma step_1_main_similar: "i  j  A  carrier_mat n n  similar_mat (step_1_main n i j A) A"
proof (induct i j A taking: n rule: step_1_main.induct)
  case (1 i j A)
  note [simp] = step_1_main.simps[of n i j A]
  from 1(3-) have ij: "i  j" and A: "A  carrier_mat n n" by auto
  show ?case
  proof (cases "j  n")
    case True
    thus ?thesis using similar_mat_refl[OF A] by simp
  next
    case False 
    hence jn: "j < n" by simp
    note IH = 1(1-2)[OF False]
    show ?thesis
    proof (cases "i = 0")
      case True
      from IH(1)[OF this _ A]
      show ?thesis using jn by (simp, simp add: True)
    next
      case False
      let ?evi = "A $$ (i - 1, i - 1)"
      let ?evj = "A $$ (j,j)"
      let ?B = "if ?evi  ?evj  A $$ (i - 1, j)  0 then 
        add_col_sub_row (A $$ (i - 1, j) / (?evj - ?evi)) (i - 1) j A else A"
      obtain B where B: "B = ?B" by auto
      have Bn: "B  carrier_mat n n" unfolding B using A by simp
      from False ij jn have *: "i - 1 < n" "j < n" "i - 1  j" by auto
      have BA: "similar_mat B A" unfolding B using similar_mat_refl[OF A]
        add_col_sub_row_similar[OF A *] by auto
      from ij have "i - 1  j" by simp
      note IH = IH(2)[OF False refl refl refl refl B this Bn]
      from False jn have id: "step_1_main n i j A = step_1_main n (i - 1) j B"
        unfolding B by (simp add: Let_def)
      show ?thesis unfolding id
        by (rule similar_mat_trans[OF IH BA])
    qed
  qed
qed

(* step_2_main n j A is similar to A for A in carrier_mat n n.
   Proof: induction over step_2_main.induct. When the guard on j and n holds the result is A.
   Otherwise the update B depends on lookup_ev (A $$ (j,j)) j A: for None it is A, for Some i it is
   swap_cols_rows_block (Suc i) j A, which is similar to A by swap_cols_rows_block_similar using the
   facts from lookup_ev_Some. The recursive call on (Suc j, B) is handled by the induction
   hypothesis. *)
private lemma step_2_main_similar: "A  carrier_mat n n  similar_mat (step_2_main n j A) A"
proof (induct j A taking: n rule: step_2_main.induct)
  case (1 j A)
  note [simp] = step_2_main.simps[of n j A]
  from 1(2) have A: "A  carrier_mat n n" .
  show ?case
  proof (cases "j  n")
    case True
    thus ?thesis using similar_mat_refl[OF A] by simp
  next
    case False 
    hence jn: "j < n" by simp
    note IH = 1(1)[OF False]
    let ?look = "lookup_ev (A $$ (j,j)) j A"
    let ?B = "case ?look of 
          None  A
        | Some i  swap_cols_rows_block (Suc i) j A"
    obtain B where B: "B = ?B" by auto
    have id: "step_2_main n j A = step_2_main n (Suc j) B" unfolding B using False by simp
    have Bn: "B  carrier_mat n n" unfolding B using A by (auto split: option.splits)
    have BA: "similar_mat B A" 
    proof (cases ?look)
      case None
      thus ?thesis unfolding B using similar_mat_refl[OF A] by simp
    next
      case (Some i)
      from lookup_ev_Some[OF this]
      show ?thesis unfolding B Some
        by (auto intro: swap_cols_rows_block_similar[OF A jn])
    qed
    show ?thesis unfolding id
      by (rule similar_mat_trans[OF IH[OF refl B Bn] BA])
  qed
qed

(* step_3_a i j A is similar to A for A in carrier_mat n n with i < j and j < n.
   Proof: induction over step_3_a.induct. The base case returns A. In the step case the single
   update B (an add_col_sub_row on Suc i and j, or A itself) is similar to A by
   add_col_sub_row_similar, whose index side conditions follow from Suc i < j and j < n; the
   recursive call is covered by the induction hypothesis. *)
private lemma step_3_a_similar: "A  carrier_mat n n  i < j  j < n  similar_mat (step_3_a i j A) A"
proof (induct i j A rule: step_3_a.induct)
  case (1 j A)
  thus ?case by (simp add: similar_mat_refl)
next
  case (2 i j A)
  from 2(2-) have A: "A  carrier_mat n n" and ij: "Suc i < j" and j: "j < n" by auto
  let ?B = "if A $$ (i, i + 1) = 1  A $$ (i, j)  0 
    then add_col_sub_row (- A $$ (i, j)) (Suc i) j A else A"
  obtain B where B: "B = ?B" by auto
  from A have Bn: "B  carrier_mat n n" unfolding B by simp
  note IH = 2(1)[OF refl B Bn _ j]
  have id: "step_3_a (Suc i) j A = step_3_a i j B" unfolding B by (simp add: Let_def)
  from ij j have *: "Suc i < n" "j < n" "Suc i  j" by auto
  have BA: "similar_mat B A" unfolding B
    using similar_mat_refl[OF A] add_col_sub_row_similar[OF A *] by auto
  show ?case unfolding id
    by (rule similar_mat_trans[OF IH BA], insert ij, auto)
qed

(* step_3_c_inner_loop val l i k A is similar to A for A in carrier_mat n n, provided l and i
   satisfy side conditions relative to each other, to k - 1 and to n.
   NOTE(review): most relation symbols in the premises are not visible in this copy - confirm
   upstream.
   Proof: induction over step_3_c_inner_loop.induct. One add_col_sub_row step is similar to A by
   add_col_sub_row_similar. For counter Suc 0 that is already the result; for a larger counter the
   side conditions carry over to l - 1 and i - 1, so the induction hypothesis applies. *)
private lemma step_3_c_inner_loop_similar: 
  "A  carrier_mat n n  l  i  k - 1  l  k - 1  i  l < n  i < n  
  similar_mat (step_3_c_inner_loop val l i k A) A"
proof (induct val l i k A rule: step_3_c_inner_loop.induct)
  case (1 val l i A)
  thus ?case using similar_mat_refl[of A] by simp
next
  case (2 val l i k A)
  let ?res = "step_3_c_inner_loop val l i (Suc k) A"
  from 2(2-) have A: "A  carrier_mat n n" and 
    *: "l  i" "k  l" "k  i" "l < n" "i < n" by auto
  let ?B = "add_col_sub_row val i l A"
  have BA: "similar_mat ?B A"
    by (rule add_col_sub_row_similar[OF A], insert *, auto)
  have B: "?B  carrier_mat n n" using A unfolding carrier_mat_def by simp
  show ?case
  proof (cases k)
    case 0
    hence id: "?res = ?B" by simp
    thus ?thesis using BA by simp
  next
    case (Suc kk)
    with * have "l - 1  i - 1" "k - 1  l - 1" "k - 1  i - 1" "l - 1 < n" "i - 1 < n" by auto
    note IH = 2(1)[OF B this]
    show ?thesis unfolding step_3_c_inner_loop.simps
      by (rule similar_mat_trans[OF IH BA])
  qed
qed

(* step_3_c x l k blocks A is similar to A for A in carrier_mat n n with l < k, k < n and a side
   condition on every pair (i_begin, i_end) in blocks relating i_end to k and i_end - i_begin to l.
   NOTE(review): the relation symbols in that side condition are not visible in this copy -
   confirm upstream.
   Proof: induction over step_3_c.induct. For the head pair the update B is either A (when
   i_end = l) or the inner loop, which is similar to A by step_3_c_inner_loop_similar; the tail is
   covered by the induction hypothesis. *)
private lemma step_3_c_similar: 
  "A  carrier_mat n n  l < k  k < n 
   ( i_begin i_end. (i_begin, i_end)  set blocks   i_end  k  i_end - i_begin  l)
   similar_mat (step_3_c x l k blocks A) A"
proof (induct x l k blocks A rule: step_3_c.induct)
  case (1 x l k A)
  thus ?case using similar_mat_refl[OF 1(1)] by simp
next
  case (2 x l k i_begin i_end blocks A)
  let ?res = "step_3_c x l k ((i_begin,i_end) # blocks) A"
  from 2(2-4) have A: "A  carrier_mat n n" and lki: "l < k" "k < n" by auto
  from 2(5) have i: "i_end  k" "i_end - i_begin  l" by auto
  let ?y = "A $$ (i_end,k)"
  let ?inner = "step_3_c_inner_loop (?y / x) l i_end (Suc i_end - i_begin) A"
  obtain B where B: 
    "B = (if i_end = l then A else ?inner)" by auto    
  hence id: "?res = step_3_c x l k blocks B"
    by simp
  have BA: "similar_mat B A" 
  proof (cases "i_end = l")
    case True
    thus ?thesis unfolding B using similar_mat_refl[OF A] by simp
  next
    case False
    hence B: "B = ?inner" and li: "l  i_end" by (auto simp: B)      
    show ?thesis unfolding B 
      by (rule step_3_c_inner_loop_similar[OF A li], insert lki i, auto)
  qed
  have Bn: "B  carrier_mat n n" using A unfolding B carrier_mat_def by simp
  note IH = 2(1)[OF B Bn lki(1-2) 2(5)]
  show ?case unfolding id
    by (rule similar_mat_trans[OF IH BA], auto)
qed

(* step_3_main n k A is similar to A for A in carrier_mat n n and k > 0.
   Proof: induction over step_3_main.induct. When the guard on k and n holds the result is A.
   Otherwise the intermediate matrices of one iteration are named as in the definition and each
   phase is shown to preserve similarity:
     B (3_a) by step_3_a_similar, using k - 1 < k;
     C (3_c) by step_3_c_similar, whose side condition comes from find_largest_block and the
             bound i_end < k from identify_blocks;
     D (3_d) by mult_col_div_row_similar, using a fact x0 about x that is derived from the
             chosen block being in the filtered list;
     E (3_e) by swap_cols_rows_block_similar, using l < k.
   The facts are chained with similar_mat_trans and combined with the induction hypothesis for
   Suc k. *)
private lemma step_3_main_similar: "A  carrier_mat n n  k > 0  similar_mat (step_3_main n k A) A"
proof (induct k A taking: n rule: step_3_main.induct)
  case (1 k A)
  from 1(2-) have A: "A  carrier_mat n n" and k: "k > 0" by auto
  note [simp] = step_3_main.simps[of n k A]
  show ?case
  proof (cases "k  n")
    case True
    thus ?thesis using similar_mat_refl[OF A] by simp
  next
    case False
    hence kn: "k < n" by simp
    obtain B where B: "B = step_3_a (k - 1) k A" by auto
    note IH = 1(1)[OF False B]
    from A have Bn: "B  carrier_mat n n" unfolding B carrier_mat_def by simp
    from k have "k - 1 < k" by simp
    from step_3_a_similar[OF A this kn] have BA: "similar_mat B A" unfolding B .
    obtain all_blocks where ab: "all_blocks = identify_blocks B k" by simp
    obtain blocks where blocks: "blocks = filter (λ block. B $$ (snd block, k)  0) all_blocks" by simp
    obtain F where F: "F = (if blocks = [] then B
       else let (l_begin,l) = find_largest_block (hd blocks) (tl blocks); x = B $$ (l, k); C = step_3_c x l k blocks B;
            D = mult_col_div_row (inverse x) k C; E = swap_cols_rows_block (Suc l) k D
        in E)" by simp
    note IH = IH[OF ab blocks F]
    have Fn: "F  carrier_mat n n" unfolding F Let_def carrier_mat_def using Bn 
      by (simp split: prod.splits)
    have FB: "similar_mat F B" 
    proof (cases "blocks = []")
      case True
      thus ?thesis unfolding F using similar_mat_refl[OF Bn] by simp
    next
      case False
      obtain l_start l where l: "find_largest_block (hd blocks) (tl blocks) = (l_start, l)" by force
      obtain x where x: "x = B $$ (l,k)" by simp
      obtain C where C: "C = step_3_c x l k blocks B" by simp
      obtain D where D: "D = mult_col_div_row (inverse x) k C" by auto
      obtain E where E: "E = swap_cols_rows_block (Suc l) k D" by auto
      from find_largest_block[OF False l] have lb: "(l_start,l)  set blocks"
        and llarge: " i_begin i_end. (i_begin,i_end)  set blocks  l - l_start  i_end - i_begin" by auto
      from lb have x0: "x  0" unfolding blocks x by simp
      {
        fix i_start i_end
        assume "(i_start,i_end)  set blocks"
        hence "(i_start,i_end)  set (identify_blocks B k)" unfolding blocks ab by simp
        from identify_blocks[OF this]
        have "i_end < k" by simp
      } note block_bound = this
      from block_bound[OF lb]
      have lk: "l < k" .
      from False have F: "F = E" unfolding E D C x F l Let_def by simp
      from Bn have Cn: "C  carrier_mat n n" unfolding C carrier_mat_def by simp
      have CB: "similar_mat C B" unfolding C
      proof (rule step_3_c_similar[OF Bn lk kn])
        fix i_begin i_end
        assume i: "(i_begin, i_end)  set blocks"
        from llarge[OF i] block_bound[OF i] 
        show "i_end  k  i_end - i_begin  l" by auto
      qed
      from x0 have "inverse x  0" by simp
      from mult_col_div_row_similar[OF Cn kn this] 
      have DC: "similar_mat D C" unfolding D .
      from Cn have Dn: "D  carrier_mat n n" unfolding D carrier_mat_def by simp
      from lk have "Suc l  k" by auto
      from swap_cols_rows_block_similar[OF Dn kn this] 
      have ED: "similar_mat E D" unfolding E .
      from similar_mat_trans[OF ED similar_mat_trans[OF DC 
        similar_mat_trans[OF CB similar_mat_refl[OF Bn]]]]
      show ?thesis unfolding F .
    qed
    have "0 < Suc k" by simp
    note IH = IH[OF Fn this]
    have id: "step_3_main n k A = step_3_main n (Suc k) F" using kn 
      by (simp add: F Let_def blocks ab B)
    show ?thesis unfolding id
      by (rule similar_mat_trans[OF IH similar_mat_trans[OF FB BA]])
  qed
qed

(* Public similarity statement for step_1, an instance of step_1_main_similar. *)
lemma step_1_similar: "A  carrier_mat n n  similar_mat (step_1 A) A"
  unfolding step_1_def by (rule step_1_main_similar, auto)

(* Public similarity statement for step_2, an instance of step_2_main_similar. *)
lemma step_2_similar: "A  carrier_mat n n  similar_mat (step_2 A) A"
  unfolding step_2_def by (rule step_2_main_similar, auto)

(* Public similarity statement for step_3, an instance of step_3_main_similar (start column 1,
   so the premise k > 0 is discharged by auto). *)
lemma step_3_similar: "A  carrier_mat n n  similar_mat (step_3 A) A"
  unfolding step_3_def by (rule step_3_main_similar, auto)

end


subsection ‹Invariants for Proving that Result is in JNF›
context 
  fixes n :: nat
  and ty :: "'a :: field itself"
begin

(* upper triangular, ensured by precondition and then maintained:
   uppert A i j relates the condition j < i to the entry A $$ (i,j) being 0, i.e. it speaks about
   positions below the diagonal. *)
definition uppert :: "'a mat  nat  nat  bool" where
  "uppert A i j  j < i  A $$ (i,j) = 0" 

(* zeros at places where EVs differ, ensured by step 1 and then maintained:
   diff_ev A i j speaks about positions with i < j, the two diagonal entries A $$ (i,i) and
   A $$ (j,j), and the entry A $$ (i,j) being 0. *)
definition diff_ev :: "'a mat  nat  nat  bool" where
  "diff_ev A i j  i < j  A $$ (i,i)  A $$ (j,j)  A $$ (i,j) = 0"

(* same EVs are grouped together, ensured by step 2 and then maintained:
   ev_blocks_part m A speaks about indices i < j < k < m; when the diagonal entries at k and i
   agree, the diagonal entry at j in between agrees with them as well. *)
definition ev_blocks_part :: "nat  'a mat  bool" where
  "ev_blocks_part m A   i j k. i < j  j < k  k < m  A $$ (k,k) = A $$ (i,i)  A $$ (j,j) = A $$ (i,i)"

(* ev_blocks_part for the full size n fixed in the surrounding context. *)
definition ev_blocks :: "'a mat  bool" where
  "ev_blocks  ev_blocks_part n"

text ‹In step 3, the invariants that hold depend on the current iteration $k$.
  The columns left of $k$ will be in JNF, the columns right of $k$ or equal to $k$ will satisfy @{const uppert}, @{const diff_ev}, 
  and @{const ev_blocks}, and the column at $k$ will have one of the following properties, which are ensured in the
  different phases of step 3.›

(* Left of a one is a zero: In a row of shape 0 ... 0 ev 1 0 ... 0 entry, the entry will be 0,
   ensured by step 3 a and then maintained.
   The definition has three parts: one about A $$ (i,j) = 0 for Suc i < j when A $$ (i,Suc i) = 1,
   and two further parts with the same shape as the bodies of uppert and diff_ev. *)
private definition one_zero :: "'a mat  nat  nat  bool" where
  "one_zero A i j  
    (Suc i < j  A $$ (i,Suc i) = 1  A $$ (i,j) = 0)  
    (j < i  A $$ (i,j) = 0) 
    (i < j  A $$ (i,i)  A $$ (j,j)  A $$ (i,j) = 0)"

(* There is exactly one row   0 ... 0 ev 0 ... 0 entry with entry != 0 and that entry is x,
   ensured by step 3 c and then maintained.
   The predicate takes l, k and x first; its column argument j is not used in the body, which
   only speaks about column k: A $$ (i,k) = 0 under a condition on i and the set {k,l}, and
   A $$ (l,k) = x. *)
private definition single_non_zero :: "nat  nat  'a  'a mat  nat  nat  bool" where
  "single_non_zero  λ l k x. (λ A i j. (i  {k,l}  A $$ (i,k) = 0)  A $$ (l,k) = x)"

(* There is exactly one row   0 ... 0 ev 0 ... 0 entry with entry != 0 and that entry is 1,
   ensured by step 3 d and then maintained.
   Same shape as single_non_zero with the value 1 in place of x; the column argument j is not
   used in the body. *)
private definition single_one :: "nat  nat  'a mat  nat  nat  bool" where
  "single_one  λ l k. (λ A i j. (i  {k,l}  A $$ (i,k) = 0)  A $$ (l,k) = 1)"

(* there is at most one row   0 ... 0 ev 0 ... 0 entry with entry != 0 and that entry is 1 and there
   are no zeros between ev and the entry, ensured by step 3 e.
   The body only constrains column j = k; it mentions the alternatives A $$ (i,j) = 0, i = j, and
   an entry 1 at j = Suc i with equal diagonal entries at i and j. *)
private definition lower_one :: "nat  'a mat  nat  nat  bool" where
  "lower_one k A i j = (j = k  
    (A $$ (i,j) = 0  i = j  (A $$ (i,j) = 1  j = Suc i  A $$ (i,i) = A $$ (j,j))))"

(* the desired property: we have a jordan block matrix.
   jb A i j has two parts: for Suc i = j the entry A $$ (i,j) is related to the set {0,1}; for
   off-diagonal positions that are not on the superdiagonal or whose diagonal entries at i and j
   are related as written, the entry A $$ (i,j) is 0.
   NOTE(review): the (in)equality and membership symbols are not visible in this copy -
   confirm upstream. *)
definition jb :: "'a mat  nat  nat  bool" where
  "jb A i j  (Suc i = j  A $$ (i,j)  {0,1}) 
   (i  j  (Suc i  j  A $$ (i,i)  A $$ (j,j))  A $$ (i,j) = 0)"

text ‹The following properties are useful to easily ensure the above invariants 
  just from invariants of other matrices. The properties are essential in showing
  that the blocks identified in step 3b are the same as one would identify for
  the matrices in the upcoming steps 3c and 3d.›
 
(* same_diag A B: the diagonal entries of A and B agree at every index i < n. *)
definition same_diag :: "'a mat  'a mat  bool" where
  "same_diag A B   i < n. A $$ (i,i) = B $$ (i,i)"

(* same_upto j A B: the entries of A and B agree at every position (i',j') with i' < n and
   j' < j, i.e. in the columns left of j. *)
private definition same_upto :: "nat  'a mat  'a mat  bool" where
  "same_upto j A B   i' j'. i' < n  j' < j  A $$ (i',j') = B $$ (i',j')"

text ‹Definitions stating where the properties hold›

(* inv_all p A: the position predicate p A i j holds for all i < n and j < n. *)
definition inv_all :: "('a mat  nat  nat  bool)  'a mat  bool" where
  "inv_all p A   i j. i < n  j < n  p A i j"

(* inv_part p A m_i m_j: p A i j for positions with i < n and j < n that are either in a column
   j < m_j or in column j = m_j with a condition on i relative to m_i.
   NOTE(review): the connectives and the relation between i and m_i are not visible in this
   copy - confirm upstream. *)
private definition inv_part :: "('a mat  nat  nat  bool)  'a mat  nat  nat  bool" where
  "inv_part p A m_i m_j   i j. i < n  j < n  j < m_j  j = m_j  i  m_i  p A i j"

(* inv_upto p A m: p A i j for positions with i < n, j < n and column j < m. *)
private definition inv_upto :: "('a mat  nat  nat  bool)  'a mat  nat  bool" where
  "inv_upto p A m   i j. i < n  j < n  j < m  p A i j"

(* inv_from p A m: p A i j for positions with i < n, j < n and column j > m. *)
private definition inv_from :: "('a mat  nat  nat  bool)  'a mat  nat  bool" where
  "inv_from p A m   i j. i < n  j < n  j > m  p A i j"

(* inv_at p A m: p A i m for every row i < n, i.e. the predicate in the single column m. *)
private definition inv_at :: "('a mat  nat  nat  bool)  'a mat  nat  bool" where
  "inv_at p A m   i. i < n  p A i m"

(* inv_from_bot p A mi: a one-index predicate p A i for the indices i < n delimited by mi.
   NOTE(review): the relation between i and mi is not visible in this copy; inv_from_bot_step
   suggests the indices from mi upwards - confirm upstream. *)
private definition inv_from_bot :: "('a mat  nat  bool)  'a mat  nat  bool" where
  "inv_from_bot p A mi   i. i  mi  i < n  p A i"

text ‹Auxiliary Lemmas on Handling, Comparing, and Accessing Invariants›

(* Pointwise relation between jb and uppert at the same position; by unfolding both definitions. *)
lemma jb_imp_uppert: "jb A i j  uppert A i j"
  unfolding jb_def uppert_def by auto

(* Destruction rule for ev_blocks_part: the body of the definition for concrete i < j < k < m. *)
private lemma ev_blocks_partD:
  "ev_blocks_part m A  i < j  j < k  k < m  A $$ (k,k) = A $$ (i,i)  A $$ (j,j) = A $$ (i,i)"
  unfolding ev_blocks_part_def by auto

(* Variant of ev_blocks_partD with weaker order premises between i, j and k (symbols not visible
   in this copy). The proof splits on whether j coincides with i or k: if not, the strict premises
   i < j and j < k hold and ev_blocks_partD applies; otherwise the claim follows from the
   assumed equation of diagonal entries. *)
private lemma ev_blocks_part_leD:
  assumes "ev_blocks_part m A"
  "i  j" "j  k" "k < m" "A $$ (k,k) = A $$ (i,i)" 
  shows "A $$ (j,j) = A $$ (i,i)"
proof -  
  show ?thesis
  proof (cases "i = j  j = k")
    case False
    with assms(2-3) have "i < j" "j < k" by auto
    from ev_blocks_partD[OF assms(1) this assms(4-)] show ?thesis .
  next
    case True
    thus ?thesis using assms(5) by auto
  qed
qed

(* Introduction rule for ev_blocks_part: the body of the definition as a premise. *)
private lemma ev_blocks_partI:
  assumes " i j k. i < j  j < k  k < m  A $$ (k,k) = A $$ (i,i)  A $$ (j,j) = A $$ (i,i)"
  shows "ev_blocks_part m A"
  using assms unfolding ev_blocks_part_def by blast

(* ev_blocks_partD specialised to ev_blocks, i.e. to the bound n. *)
private lemma ev_blocksD:
  "ev_blocks A  i < j  j < k  k < n  A $$ (k,k) = A $$ (i,i)  A $$ (j,j) = A $$ (i,i)"
  unfolding ev_blocks_def by (rule ev_blocks_partD)

(* ev_blocks_part_leD specialised to ev_blocks, i.e. to the bound n. *)
private lemma ev_blocks_leD:
  "ev_blocks A  i  j  j  k  k < n  A $$ (k,k) = A $$ (i,i)  A $$ (j,j) = A $$ (i,i)"
  unfolding ev_blocks_def by (rule ev_blocks_part_leD)

(* Destruction rule for inv_all: p A i j at a concrete position with i < n and j < n. *)
lemma inv_allD: "inv_all p A  i < n  j < n  p A i j"
  unfolding inv_all_def by auto

(* Introduction rule for inv_all. *)
private lemma inv_allI: assumes " i j. i < n  j < n  p A i j"
  shows "inv_all p A" 
  using assms unfolding inv_all_def by blast

(* Introduction rule for inv_part: the body of the definition as a premise. *)
private lemma inv_partI: assumes " i j. i < n  j < n  j < m_j  j = m_j  i  m_i  p A i j"
  shows "inv_part p A m_i m_j"
  using assms unfolding inv_part_def by auto

(* Destruction rules for inv_part at a position with i < n and j < n: one for a column j < m_j, one
   for the column j = m_j with the row condition, and one for the combined condition. *)
private lemma inv_partD: assumes "inv_part p A m_i m_j" "i < n" "j < n" 
  shows "j < m_j  p A i j"
  and "j = m_j  i  m_i  p A i j"
  and "j < m_j  j = m_j  i  m_i  p A i j"
  using assms unfolding inv_part_def by auto

(* Introduction rule for inv_upto. *)
private lemma inv_uptoI: assumes " i j. i < n  j < n  j < m  p A i j"
  shows "inv_upto p A m"
  using assms unfolding inv_upto_def by auto

(* Destruction rule for inv_upto at a position with i < n, j < n and j < m. *)
private lemma inv_uptoD: assumes "inv_upto p A m" "i < n" "j < n" "j < m"
  shows "p A i j"
  using assms unfolding inv_upto_def by auto

(* Extends inv_upto from m to Suc m given the predicate in column m. The proof splits on whether
   the column j equals m. *)
private lemma inv_upto_Suc: assumes "inv_upto p A m"
  and " i. i < n  p A i m"
  shows "inv_upto p A (Suc m)"
proof (intro inv_uptoI)
  fix i j
  assume "i < n" "j < n" "j < Suc m"
  thus "p A i j" using inv_uptoD[OF assms(1), of i j] assms(2)[of i] by (cases "j = m", auto)
qed

(* Monotonicity of inv_upto in the predicate: if p A i j yields q A i j for i < n and j < k, then
   inv_upto p A k yields inv_upto q A k. *)
private lemma inv_upto_mono: assumes " i j. i < n  j < k  p A i j  q A i j"
  shows "inv_upto p A k  inv_upto q A k"
  using assms unfolding inv_upto_def by auto

(* Introduction rule for inv_from. *)
private lemma inv_fromI: assumes " i j. i < n  j < n  j > m  p A i j"
  shows "inv_from p A m"
  using assms unfolding inv_from_def by auto

(* Destruction rule for inv_from at a position with i < n, j < n and j > m. *)
private lemma inv_fromD: assumes "inv_from p A m" "i < n" "j < n" "j > m"
  shows "p A i j"
  using assms unfolding inv_from_def by auto

(* Introduction rule for inv_at, declared as an intro rule. *)
private lemma inv_atI[intro]: assumes " i. i < n  p A i m"
  shows "inv_at p A m"
  using assms unfolding inv_at_def by auto

(* Destruction rule for inv_at at a row i < n. *)
private lemma inv_atD: assumes "inv_at p A m" "i < n"
  shows "p A i m"
  using assms unfolding inv_at_def by auto
  
(* inv_all p A yields inv_part p A m_i m_j; by unfolding both definitions.
   NOTE(review): the first premise is written "m i" (an application of m to i), whereas the
   conclusion and the second premise use m_i and m_j. Presumably "m_i" was meant; the proof
   given here only unfolds definitions, so it should be unaffected - confirm before
   changing the statement. *)
private lemma inv_all_imp_inv_part: "m i  n  m_j  n  inv_all p A  inv_part p A m_i m_j"
  unfolding inv_all_def inv_part_def by auto

(* inv_all coincides with inv_part at the bounds n and n. *)
private lemma inv_all_eq_inv_part: "inv_all p A = inv_part p A n n"
  unfolding inv_all_def inv_part_def by auto

(* For m_j < n, inv_part with row bound 0 in column m_j equals inv_part with row bound n in column
   Suc m_j. The proof splits on whether the column j equals m_j. *)
private lemma inv_part_0_Suc: "m_j < n  inv_part p A 0 m_j = inv_part p A n (Suc m_j)"
  unfolding inv_part_def by (auto, case_tac "j = m_j", auto)

(* From inv_all uppert A: the entry A $$ (i,j) is 0 for j < i and i < n. *)
private lemma inv_all_uppertD: "inv_all uppert A  j < i  i < n  A $$ (i,j) = 0"
  unfolding inv_all_def uppert_def by auto

(* From inv_all diff_ev A: the entry A $$ (i,j) is 0 for i < j, j < n and the stated relation
   between the diagonal entries at i and j (symbol not visible in this copy). *)
private lemma inv_all_diff_evD: "inv_all diff_ev A  i < j  j < n 
   A $$ (i, i)  A $$ (j, j)  A $$ (i,j) = 0"
  unfolding inv_all_def diff_ev_def by auto

(* Combination of inv_all_diff_evD and inv_all_uppertD: under both invariants and the assumption
   neg on the diagonal entries at i and j, the entry A $$ (i,j) is 0 for any i < n and j < n.
   The proof derives that i and j are distinct from neg and then treats i < j with
   inv_all_diff_evD and j < i with inv_all_uppertD. *)
private lemma inv_all_diff_ev_uppertD: assumes "inv_all diff_ev A"
  "inv_all uppert A"
  "i < n" "j < n"
  and neg: "A $$ (i, i)  A $$ (j, j)"
  shows "A $$ (i,j) = 0"
proof -
  from neg have "i  j" by auto
  hence "i < j  j < i" by arith
  thus ?thesis
  proof 
    assume "i < j"
    from inv_all_diff_evD[OF assms(1) this j < n neg] show ?thesis .
  next
    assume "j < i"
    from inv_all_uppertD[OF assms(2) this i < n] show ?thesis .
  qed
qed

(* Step rule for inv_from_bot: p A i together with inv_from_bot p A (Suc i) gives
   inv_from_bot p A i. The proof splits on whether the bound index equals i. *)
private lemma inv_from_bot_step: "p A i  inv_from_bot p A (Suc i)  inv_from_bot p A i"
  unfolding inv_from_bot_def by (auto, case_tac "ia = i", auto)

(* same_diag is reflexive (simp rule). *)
private lemma same_diag_refl[simp]: "same_diag A A" unfolding same_diag_def by auto
(* same_diag is transitive. *)
private lemma same_diag_trans: "same_diag A B  same_diag B C  same_diag A C" 
  unfolding same_diag_def by auto

(* ev_blocks transfers from A to B when both have the same diagonal. *)
private lemma same_diag_ev_blocks: "same_diag A B  ev_blocks A  ev_blocks B"
  unfolding same_diag_def ev_blocks_def ev_blocks_part_def by auto

(* Introduction rule for same_upto, declared as an intro rule. *)
private lemma same_uptoI[intro]: assumes " i' j'. i' < n  j' < j  A $$ (i',j') = B $$ (i',j')"
  shows "same_upto j A B"
  using assms unfolding same_upto_def by blast

(* Destruction rule for same_upto at a position with i' < n and j' < j, declared as a dest rule. *)
private lemma same_uptoD[dest]: assumes "same_upto j A B" "i' < n" "j' < j" 
  shows "A $$ (i',j') = B $$ (i',j')"
  using assms unfolding same_upto_def by blast

(* same_upto is reflexive (simp rule). *)
private lemma same_upto_refl[simp]: "same_upto j A A" unfolding same_upto_def by auto

(* same_upto is transitive for a fixed column bound j. *)
private lemma same_upto_trans: "same_upto j A B  same_upto j B C  same_upto j A C" 
  unfolding same_upto_def by auto

(* The jb invariant on the columns left of j transfers from A to B when the two matrices agree
   on those columns. *)
private lemma same_upto_inv_upto_jb: "same_upto j A B  inv_upto jb A j  inv_upto jb B j"
  unfolding inv_upto_def same_upto_def jb_def by auto

(* Pointwise relation between jb and diff_ev at the same position; by unfolding both definitions. *)
lemma jb_imp_diff_ev: "jb A i j  diff_ev A i j"
  unfolding jb_def diff_ev_def by auto

(* ev_blocks transfers from B to A when both have the same diagonal (the direction opposite to
   same_diag_ev_blocks). *)
private lemma ev_blocks_diag: 
  "same_diag A B  ev_blocks B  ev_blocks A"
  unfolding ev_blocks_def ev_blocks_part_def same_diag_def by auto

(* inv_all yields inv_from for any column bound k. *)
private lemma inv_all_imp_inv_from: "inv_all p A  inv_from p A k"
  unfolding inv_all_def inv_from_def by auto

(* inv_all yields inv_at for any column k < n. *)
private lemma inv_all_imp_inv_at: "inv_all p A  k < n  inv_at p A k"
  unfolding inv_all_def inv_at_def by auto

(* Assembles inv_all diff_ev A and inv_all uppert A from three column ranges: jb on the columns
   left of k, diff_ev and uppert on the columns right of k, and an arbitrary predicate p in column
   k that is strong enough to give diff_ev and uppert there (last assumption).
   The proof fixes a position with i < n and j < n and splits on j < k, j = k and j > k. *)
private lemma inv_from_upto_at_all: 
  assumes "inv_upto jb A k" "inv_from diff_ev A k" "inv_from uppert A k" "inv_at p A k"
  and " i. i < n  p A i k  diff_ev A i k  uppert A i k"
  shows "inv_all diff_ev A" "inv_all uppert A"
proof -
  {
    fix i j
    assume ij: "i < n" "j < n"
    have "diff_ev A i j  uppert A i j"
    proof (cases "j < k")
      case True
      with assms(1) ij have "jb A i j" unfolding inv_upto_def by auto
      thus ?thesis using ij unfolding jb_def diff_ev_def uppert_def by auto
    next
      case False note ge = this
      show ?thesis
      proof (cases "j = k")
        case True
        with assms(4-) ij show ?thesis unfolding inv_at_def by auto
      next
        case False
        with ge have "j > k" by auto
        with assms(2-3) ij show ?thesis unfolding inv_from_def by auto
      qed
    qed
  }
  thus "inv_all diff_ev A" "inv_all uppert A" unfolding inv_all_def by auto
qed

(* For i < n, lower_one k B i k yields diff_ev B i k and uppert B i k;
   proved by unfolding the three definitions. *)
private lemma lower_one_diff_uppert:
  "i < n  lower_one k B i k  diff_ev B i k  uppert B i k"
   unfolding lower_one_def diff_ev_def uppert_def by auto

(* ev_block n A: the diagonal entries A $$ (i,i) and A $$ (j,j) coincide
   for all indices i, j below n. *)
definition ev_block :: "nat  'a mat  bool" where 
  " n. ev_block n A = ( i j. i < n  j < n  A $$ (i,i) = A $$ (j,j))"

(* Destruction form of ev_block: for i < n and j < n the diagonal entries
   at i and j are equal. *)
lemma ev_blockD: " n. ev_block n A  i < n  j < n  A $$ (i,i) = A $$ (j,j)"
  unfolding ev_block_def carrier_mat_def by blast

(* ev_block is transferred from A to B when same_diag A B holds.
   Proved with metis after unfolding ev_block_def and same_diag_def. *)
lemma same_diag_ev_block: "same_diag A B  ev_block n A  ev_block n B"
  unfolding ev_block_def carrier_mat_def same_diag_def by metis


subsection ‹Alternative Characterization of @{const identify_blocks} in Presence of @{const ev_block}›

(* Set-level characterization of identify_blocks_main A k list: the result
   consists of the elements of list together with the pairs (i,j) described
   by the set comprehension ?ss A k (bounds on i and j, entries A $$ (l, Suc l)
   equal to 1 inside the block, and conditions on the entries at both block
   borders). Proved by induction along identify_blocks_main.induct. *)
private lemma identify_blocks_main_iff: assumes *: "k  k'" 
  "k'  k  k > 0  A $$ (k - 1, k)  1" and "k' < n"
  shows "set (identify_blocks_main A k list) = 
  set list  {(i,j) | i j. i  j  j < k  ( l. i  l  l < j  A $$ (l, Suc l) = 1)
   (Suc j  k'  A $$ (j, Suc j)  1)  (i > 0  A $$ (i - 1, i)  1)}" (is "_ = _  ?ss A k")
  using *
proof (induct A k list rule: identify_blocks_main.induct)
  (* first case of the induction rule: closed by unfolding the simp rules *)
  case 1
  show ?case unfolding identify_blocks_main.simps by auto
next
  (* second case: index Suc i_e; i_b is the result of identify_block A i_e *)
  case (2 A i_e list)
  let ?s = "?ss A"
  obtain i_b where id: "identify_block A i_e = i_b" by force
  note IH = 2(1)[OF id[symmetric]]
  let ?res = "identify_blocks_main A (Suc i_e) list"  
  let ?rec = "identify_blocks_main A i_b ((i_b, i_e) # list)"
  note idb = identify_block[OF id]
  (* one unfolding step of identify_blocks_main *)
  hence res: "?res = ?rec" using id by simp
  from 2(2-) have iek: "i_e < k'" by simp
  from identify_block_le'[OF id] have ibe: "i_b  i_e" .
  from ibe iek have "i_b  k'" by simp
  have "k'  i_b  0 < i_b   A $$ (i_b - 1, i_b)  1" 
    using idb(2) by auto
  note IH = IH[OF i_b  k' this]
  (* auxiliary set identity used to reduce the goal to a statement about ?s *)
  have cong: " a b c d. insert a c = d  set (a # b)  c = set b  d" by auto
  show ?case unfolding res IH
  proof (rule cong)
    (* first inclusion: insert (i_b, i_e) (?s i_b) is contained in ?s (Suc i_e) *)
    from ibe have "?s i_b  ?s (Suc i_e)" by auto
    moreover 
    have inter: "l. i_b  l  l < i_e  A $$ (l, Suc l) = 1" using idb by blast
    have last: "Suc i_e  k'  A $$ (i_e, Suc i_e)  1" using 2(3) by auto
    have "(i_b, i_e)  ?s (Suc i_e)" using ibe idb(2) inter last by blast
    ultimately have "insert (i_b, i_e) (?s i_b)  ?s (Suc i_e)" by auto
    moreover  
    {
      (* second inclusion: every (i,j) in ?s (Suc i_e) is in ?s i_b or equals (i_b, i_e) *)
      fix i j
      assume ij: "(i,j)  ?s (Suc i_e)"
      hence "(i,j)  insert (i_b, i_e) (?s i_b)"
      proof (cases "j < i_b")
        case True
        with ij show ?thesis by blast
      next
        case False
        with ij have "i_b  j" "j  i_e" by auto
        {
          (* j < i_e is contradictory, hence j = i_e below *)
          assume j: "j < i_e"
          from idb(3)[OF i_b  j this] have 1: "A $$ (j, Suc j) = 1" .
          from j ‹Suc i_e  k' have "Suc j  k'" by auto
          with ij 1 have False by auto
        }
        with j  i_e have j: "j = i_e" by (cases "j = i_e", auto)
        {
          (* both i < i_b and i > i_b are contradictory, hence i = i_b below *)
          assume i: "i < i_b  i > i_b"
          hence False
          proof
            assume "i < i_b"
            hence "i_b > 0" by auto
            with idb(2) have *: "A $$ (i_b - 1, i_b)  1" by auto
            from i < i_b i_b  i_e i_e < k' have "i  i_b - 1" "i_b - 1  k'" by auto
            from i < i_b i_b  i_e j have **: "i  i_b - 1" "i_b - 1 < j" "Suc (i_b - 1) = i_b" by auto
            from ij have " l. li  l < j  A $$ (l, Suc l) = 1" by auto
            from this[OF **(1-2)] **(3) * show False by auto
          next
            assume "i > i_b"
            with ij j have "A $$ (i - 1, i)  1" and 
              *: "i - 1  i_b" "i - 1  i_e" "i - 1 < i_e" "Suc (i - 1) = i" by auto
            with idb(3)[OF *(1,3)] show False by auto
          qed
        }
        hence i: "i = i_b" by arith
        show ?thesis unfolding i j by simp
      qed
    }
    hence "?s (Suc i_e)  insert (i_b, i_e) (?s i_b)" by blast
    ultimately 
    show "insert (i_b, i_e) (?s i_b) = ?s (Suc i_e)" by blast
  qed
qed
  

(* Set-level characterization of identify_blocks A k for k < n, obtained by
   unfolding identify_blocks_def and instantiating identify_blocks_main_iff
   (with le_refl for its first premise). *)
private lemma identify_blocks_iff: assumes "k < n"
  shows "set (identify_blocks A k) = 
  {(i,j) | i j. i  j  j < k  ( l. i  l  l < j  A $$ (l, Suc l) = 1)
   (Suc j  k  A $$ (j, Suc j)  1)  (i > 0  A $$ (i - 1, i)  1)}" 
  unfolding identify_blocks_def using identify_blocks_main_iff[OF le_refl _ k < n] by auto

(* Destruction rules for membership (i,j) in set (identify_blocks A k) with
   k < n: five separate facts about i, j and entries of A, all derived from
   identify_blocks_iff.
   NOTE(review): the last fact also mentions the diagonal entries
   A $$ (i - 1, i - 1) and A $$ (k,k), so it is not literally the last clause
   of identify_blocks_iff - confirm that this form is the intended one. *)
private lemma identify_blocksD: assumes "k < n" and "(i,j)  set (identify_blocks A k)"
  shows "i  j" "j < k" 
  " l. i  l  l < j  A $$ (l, Suc l) = 1"
  "Suc j  k  A $$ (j, Suc j)  1"
  "i > 0  A $$ (i - 1, i - 1)  A $$ (k,k)  A $$ (i - 1, i)  1" 
  using assms unfolding identify_blocks_iff[OF assms(1)] by auto

(* Introduction rule for membership in set (identify_blocks A k): the listed
   assumptions are the clauses of the set comprehension in identify_blocks_iff
   (plus k < n), so the claim follows after unfolding that lemma. *)
private lemma identify_blocksI: assumes inv: "k < n"
  "i  j" "j < k" " l. i  l  l < j  A $$ (l, Suc l) = 1"
  "Suc j  k  A $$ (j, Suc j)  1" "i > 0  A $$ (i - 1, i)  1"
  shows "(i,j)  set (identify_blocks A k)"
  unfolding identify_blocks_iff[OF inv(1)] using inv by blast

(* Given the stated condition on A $$ (i, Suc i) and on Suc i relative to k,
   and k < n, the pair (identify_block A i, i) is an element of
   set (identify_blocks A k). Proved via identify_blocksI, using the facts
   provided by identify_block. *)
private lemma identify_blocks_rev: assumes "A $$ (i, Suc i) = 0  Suc i < k  Suc i = k"
  and inv: "k < n"
  shows "(identify_block A i, i)  set (identify_blocks A k)"
proof -
  (* name the result of identify_block so that its properties can be used *)
  obtain j where id: "identify_block A i = j" by force
  note idb = identify_block[OF this]
  show ?thesis unfolding id 
    by (rule identify_blocksI[OF inv], insert idb assms, auto)
qed


subsection ‹Proving the Invariants›

(* add_col_sub_row a i j leaves the diagonal entry (k,k) unchanged, provided
   A is in carrier_mat n n, inv_all uppert A holds, i < j, j < n and k < n.
   The entry is computed with add_col_sub_index_row; the facts from
   inv_all_uppertD are supplied to the final auto call. *)
private lemma add_col_sub_row_diag: assumes A: "A  carrier_mat n n"
  and ut: "inv_all uppert A"
  and ijk: "i < j" "j < n" "k < n"
  shows "add_col_sub_row a i j A $$ (k,k) = A $$ (k,k)"
proof -
  from inv_all_uppertD[OF ut]
  show ?thesis
    by (subst add_col_sub_index_row, insert A ijk, auto)
qed

(* Preservation of diff_ev for "old" positions under
   add_col_sub_row a (i - 1) j: for a position (i',j') that satisfies
   `choice` (i.e. is covered by inv_part diff_ev A i j), diff_ev still holds
   for the transformed matrix. Requires inv_all uppert A. *)
private lemma add_col_sub_row_diff_ev_part_old: assumes A: "A  carrier_mat n n"
  and ij: "i  j" "i  0" "i < n" "j < n" "i' < n" "j' < n"
  and choice: "j' < j  j' = j  i'  i"
  and old: "inv_part diff_ev A i j"
  and ut: "inv_all uppert A"
  shows "diff_ev (add_col_sub_row a (i - 1) j A) i' j'"
  unfolding diff_ev_def
proof (intro impI)
  assume ij': "i' < j'"
  let ?A = "add_col_sub_row a (i - 1) j A"
  assume neq: "?A $$ (i',i')  ?A $$ (j',j')"
  from A have dim: "dim_row A = n" "dim_col A = n" by auto
  note utd = inv_all_uppertD[OF ut]
  let ?i = "i - 1"
  have "?i < j" using i  j i  0 i < n by auto
  from utd[OF this j < n] have Aji: "A $$ (j,?i) = 0" by simp
  (* the transformation does not change diagonal entries *)
  from add_col_sub_row_diag[OF A ut ?i < j j < n]
  have diag: " k. k < n  ?A $$ (k,k) = A $$ (k,k)" .
  (* so the assumption on the diagonal of ?A carries over to A *)
  from neq[unfolded diag[OF i' < n] diag[OF j' < n]] 
  have neq: "A $$ (i',i')  A $$ (j',j')" by auto
  {
    (* the old invariant gives that the entry of A at (i',j') is 0 *)
    from inv_partD(3)[OF old i' < n j' < n choice]
    have "diff_ev A i' j'" by auto
    with neq ij' have "A $$ (i',j') = 0" unfolding diff_ev_def by auto
  } note zero = this
  {
    (* auxiliary zero entry of A obtained from inv_all uppert (column ?i) *)
    assume "i'  ?i" "j' = j"
    with ij' ij(1) choice have "i' > ?i" by auto
    from utd[OF this] ij
    have "A $$ (i', ?i) = 0" by auto
  } note 1 = this
  {
    (* auxiliary zero entry of A obtained from inv_all uppert (row j) *)
    assume "j'  j" "i' = ?i"
    with ij' ij(1) choice have "j > j'" by auto
    from utd[OF this] ij
    have "A $$ (j, j') = 0" by auto
  } note 2 = this
  (* the position (?i, j) itself is excluded, which removes the first
     if-branch of add_col_sub_index_row *)
  from ij' ij choice have "(i' = ?i  j' = j) = False" by arith
  note id = add_col_sub_index_row[of i' A j' j a ?i, unfolded dim this if_False, 
    OF i' < n i' < n j' < n j' < n j < n]
  show "?A $$ (i',j') = 0" unfolding id zero using 1 2 by auto
qed

(* inv_all uppert is preserved by add_col_sub_row a i j, for A in
   carrier_mat n n with i < j and j < n. After unfolding inv_all_def and
   uppert_def, each entry (i',j') with j' < i' is shown to be 0 via
   add_col_sub_index_row. *)
private lemma add_col_sub_row_uppert: assumes "A  carrier_mat n n"
  and "i < j"
  and "j < n"
  and inv: "inv_all uppert (A :: 'a mat)"
  shows "inv_all uppert (add_col_sub_row a i j A)"
  unfolding inv_all_def uppert_def
proof (intro allI impI)
  fix i' j'
  assume *: "i' < n" "j' < n" "j' < i'"
  note inv = inv_allD[OF inv, unfolded uppert_def]
  show "add_col_sub_row a i j A $$ (i', j') = 0"
    by (subst add_col_sub_index_row, insert assms * inv, auto)
qed

(* Invariant lemma for step_1_main: starting from a matrix A in
   carrier_mat n n with inv_all uppert A and the partial invariant
   inv_part diff_ev A i j, the result step_1_main n i j A satisfies
   inv_all uppert and inv_all diff_ev.
   Proved by induction along step_1_main.induct with a case analysis on
   whether j has reached n and whether i = 0. *)
private lemma step_1_main_inv: "i  j 
   A  carrier_mat n n 
   inv_all uppert A 
   inv_part diff_ev A i j 
   inv_all uppert (step_1_main n i j A)  inv_all diff_ev (step_1_main n i j A)"
proof (induct i j A taking: n rule: step_1_main.induct)
  case (1 i j A)
  let ?i = "i - 1"
  note [simp] = step_1_main.simps[of n i j A]
  from 1(3-) have ij: "i  j" and A: "A  carrier_mat n n" and inv: "inv_all uppert A" 
    "inv_part diff_ev A i j" by auto
  show ?case
  proof (cases "j  n")
    case True
    (* in this case the partial invariant already yields the full one *)
    thus ?thesis using inv by (simp add: inv_all_eq_inv_part, auto simp: inv_part_def)
  next
    case False 
    hence jn: "j < n" by simp
    note IH = 1(1-2)[OF False]
    show ?thesis
    proof (cases "i = 0")
      case True
      (* i = 0: re-express the invariant for index j + 1 and use the first
         induction hypothesis *)
      from inv[unfolded True inv_part_0_Suc[OF jn]]
      have inv2: "inv_part diff_ev A n (j + 1)" by simp
      have "inv_part diff_ev A (j + 1) (j + 1)" 
      proof (intro inv_partI)
        fix i' j'
        assume ij: "i' < n" "j' < n" and choice: "j' < j + 1  j' = j + 1  j + 1  i'"
        from inv_partD[OF inv2 ij] choice
        show "diff_ev A i' j'" using jn unfolding diff_ev_def by auto
      qed
      from IH(1)[OF True _ A inv(1) this]
      show ?thesis using jn by (simp, simp add: True)
    next
      case False 
      (* i > 0: B is either A or the result of add_col_sub_row at
         position (i - 1, j), depending on ?choice *)
      let ?evi = "A $$ (?i,?i)"
      let ?evj = "A $$ (j,j)"
      let ?choice = "?evi  ?evj  A $$ (?i, j)  0"
      let ?A = "add_col_sub_row (A $$ (?i, j) / (?evj - ?evi)) ?i j A"
      let ?B = "if ?choice then ?A else A"
      obtain B where B: "B = ?B" by auto
      have Bn: "B  carrier_mat n n" unfolding B using A by simp
      from False ij jn have *: "?i < j" "j < n" "?i < n" by auto
      (* uppert is preserved in both branches *)
      have inv1: "inv_all uppert B" unfolding B using inv add_col_sub_row_uppert[OF A *(1-2) inv(1)]
        by auto
      note inv2 = inv_partD[OF inv(2)]
      (* the partial diff_ev invariant is extended by the position (i - 1, j) *)
      have inv2: "inv_part diff_ev B ?i j"
      proof (cases ?choice)
        case False
        (* no transformation: B = A *)
        hence B: "B = A" unfolding B by auto
        show ?thesis unfolding B
        proof (rule inv_partI)
          fix i' j'
          assume ij: "i' < n" "j' < n" and "j' < j  j' = j  ?i  i'"
          hence choice: "(j' < j  j' = j  i  i')  j' = j  i' = ?i" by auto
          note inv2 = inv2[OF ij]
          from choice
          show "diff_ev A i' j'"
          proof
            assume "j' < j  j' = j  i  i'"
            from inv2(3)[OF this] show ?thesis .
          next
            assume "j' = j  i' = ?i"
            thus ?thesis using False unfolding diff_ev_def by auto
          qed
        qed
      next
        case True
        (* transformation applied: B = ?A *)
        hence B: "B = ?A" unfolding B by auto
        from * True have "i < n" by auto
        note old = add_col_sub_row_diff_ev_part_old[OF A i  j i  0 i < n j < n 
          _ _ _ inv(2) inv(1)]
        show ?thesis unfolding B
        proof (rule inv_partI)
          fix i' j'
          assume ij: "i' < n" "j' < n" and "j' < j  j' = j  ?i  i'"
          hence choice: "(j' < j  j' = j  i  i')  j' = j  i' = ?i" by auto
          note inv2 = inv2[OF ij]
          from choice
          show "diff_ev ?A i' j'"
          proof
            (* old positions: handled by add_col_sub_row_diff_ev_part_old *)
            assume "j' < j  j' = j  i  i'"
            from old[OF ij this] show ?thesis .
          next
            (* new position (i - 1, j): compute the entry of ?A explicitly and
               show that it is 0 *)
            assume "j' = j  i' = ?i"
            hence ij': "j' = j" "i' = ?i" by auto
            note diag = add_col_sub_row_diag[OF A inv(1) ?i < j j < n]
            show ?thesis unfolding ij' diff_ev_def diag[OF j < n] diag[OF ?i < n]
            proof (intro impI)
              from True have neq: "?evi  ?evj" by simp
              note ut = inv_all_uppertD[OF inv(1)]
              obtain i' where i': "i' = i - Suc 0" by auto
              obtain diff where diff: "diff = ?evj - A $$ (i',i')" by auto
              from neq have [simp]: "diff  0" unfolding diff i' by auto
              from ut[OF ?i < j j < n] have [simp]: "A $$ (j,i') = 0" unfolding diff i' by simp
              have "?A $$ (?i, j) = 
                A $$ (i', j) + (A $$ (i', j) * A $$ (i', i') -
                A $$ (i', j) * A $$ (j, j)) / diff"
                by (subst add_col_sub_index_row, insert A *, auto simp: diff[symmetric] i'[symmetric] field_simps)
              also have "A $$ (i', j) * A $$ (i', i') - A $$ (i', j) * A $$ (j, j)
                = - A $$ (i',j) * diff" by (simp add: diff i' field_simps)
              also have " / diff = - A $$ (i',j)" by simp
              finally show "?A $$ (?i,j) = 0" by simp
            qed
          qed
        qed
      qed
      from ij have "i - 1  j" by simp
      (* apply the second induction hypothesis to B with index i - 1 *)
      note IH = IH(2)[OF False refl refl refl refl B this Bn inv1 inv2]
      from False jn have id: "step_1_main n i j A = step_1_main n (i - 1) j B"
        unfolding B by (simp add: Let_def)
      show ?thesis unfolding id by (rule IH)
    qed
  qed
qed

(* Invariant lemma for step_2_main: for A in carrier_mat n n with
   inv_all uppert A, inv_all diff_ev A and ev_blocks_part j A, the result
   step_2_main n j A satisfies inv_all uppert, inv_all diff_ev and ev_blocks.
   Proved by induction along step_2_main.induct. In each step the matrix B is
   either A (lookup_ev returns None) or swap_cols_rows_block (Suc i) j A
   (lookup_ev returns Some i); the main work is to show that all three
   invariants hold for B with index Suc j. *)
private lemma step_2_main_inv: "A  carrier_mat n n 
   inv_all uppert A 
   inv_all diff_ev A 
   ev_blocks_part j A 
   inv_all uppert (step_2_main n j A)  inv_all diff_ev (step_2_main n j A) 
     ev_blocks (step_2_main n j A)"
proof (induct j A taking: n rule: step_2_main.induct)
  case (1 j A)
  note [simp] = step_2_main.simps[of n j A]
  from 1(2-) have A: "A  carrier_mat n n" 
    and inv: "inv_all uppert A" "inv_all diff_ev A" "ev_blocks_part j A" by auto
  show ?case
  proof (cases "j  n")
    case True
    (* in this case ev_blocks_part j A already gives ev_blocks A *)
    with inv(3) have "ev_blocks A" unfolding ev_blocks_def ev_blocks_part_def by auto
    thus ?thesis using True inv(1-2) by auto
  next
    case False 
    hence jn: "j < n" by simp
    note intro = ev_blocks_partI
    note dest = ev_blocks_partD
    note IH = 1(1)[OF False]
    let ?look = "lookup_ev (A $$ (j,j)) j A"
    let ?B = "case ?look of 
          None  A
        | Some i  swap_cols_rows_block (Suc i) j A"
    obtain B where B: "B = ?B" by auto
    (* one unfolding step of step_2_main *)
    have id: "step_2_main n j A = step_2_main n (Suc j) B" unfolding B using False by simp
    have Bn: "B  carrier_mat n n" unfolding B using A by (auto split: option.splits)
    (* main claim: the invariants hold for B with index Suc j *)
    have "inv_all uppert B  inv_all diff_ev B  ev_blocks_part (Suc j) B"
    proof (cases ?look)
      case None
      (* B = A; only ev_blocks_part has to be extended to Suc j *)
      have "ev_blocks_part (Suc j) A"
      proof (intro intro)
        fix i' j' k'
        assume *: "i' < j'" "j' < k'" "k' < Suc j" "A $$ (k',k') = A $$ (i',i')"
        show "A $$ (j',j') = A $$ (i',i')"
        proof (cases "j = k'")
          case False
          with * have "k' < j" by auto
          from dest[OF inv(3) *(1-2) this *(4)]
          show ?thesis .
        next
          case True
          (* contradiction with lookup_ev_None *)
          with lookup_ev_None[OF None, of i'] * have False by simp
          thus ?thesis ..
        qed
      qed
      with None show ?thesis unfolding B using inv by auto
    next
      case (Some i)
      (* facts about the index i provided by lookup_ev_Some *)
      from lookup_ev_Some[OF Some] 
      have ij: "i < j" and id: "A $$ (i, i) = A $$ (j, j)" 
        and neq: " k. i < k  k < j  A $$ (k,k)  A $$ (j,j)" by auto
      let ?A = "swap_cols_rows_block (Suc i) j A"
      (* index map: entries of ?A are entries of A at permuted indices, see Aind *)
      let ?perm = "λ i'. if i' = Suc i then j else if Suc i < i'  i'  j then i' - 1 else i'"
      from Some have B: "B = ?A" unfolding B by simp
      have Aind: " i' j'. i' < n  j' < n  ?A $$ (i', j') = A $$ (?perm i', ?perm j')"
        by (subst swap_cols_rows_block_index, insert False A ij, auto)
      (* ev_blocks_part for ?A: case analysis on the positions of i', j', k
         relative to Suc i, computing ?perm in every case *)
      have inv_ev: "ev_blocks_part (Suc j) ?A"
      proof (intro intro)
        fix i' j' k        
        assume *: "i' < j'" "j' < k" "k < Suc j" and ki: "?A $$ (k,k) = ?A $$ (i',i')" 
        from * jn have "j' < n" "i' < n" "k < n" by auto
        note id' = Aind[OF j' < n j' < n] Aind[OF i' < n i' < n] Aind[OF k < n k < n]
        note inv_ev = dest[OF inv(3)]
        show "?A $$ (j',j') = ?A $$ (i',i')"
        proof (cases "i' < Suc i")
          case True note i' = this
          hence pi: "?perm i' = i'" by simp
          show ?thesis
          proof (cases "j' < Suc i")
            case True note j' = this
            hence pj: "?perm j' = j'" by simp
            show ?thesis
            proof (cases "k < Suc i")
              case True note k = this
              hence pk: "?perm k = k" by simp
              from True ij have "k < j" by simp
              from inv_ev[OF *(1-2) this] ki
              show ?thesis unfolding id' pi pj pk by auto
            next
              case False note kf1 = this
              show ?thesis
              proof (cases "k = Suc i")
                case True note k = this
                hence pk: "?perm k = j" by simp
                from ki id have ii': "A $$ (i, i) = A $$ (i', i')" unfolding id' pi pj pk by simp
                have ji: "A $$ (j',j') = A $$ (i',i')"
                proof (cases "j' = i")
                  case True
                  with ii' show ?thesis by simp
                next
                  case False
                  with j' < Suc i have "j' < i" by auto
                  from ki id inv_ev[OF i' < j' this ij] show ?thesis
                    unfolding id' pi pj pk by simp
                qed
                thus ?thesis unfolding id' pi pj pk .
              next
                case False note kf2 = this
                with kf1 have k: "k > Suc i" by auto
                hence pk: "?perm k = k - 1" and kj: "k - 1 < j"
                  using * k < Suc j by auto
                from k j' have "j' < k - 1" by auto
                from inv_ev[OF *(1) this kj] ki
                show ?thesis unfolding id' pi pj pk by simp
              qed
            qed
          next
            case False note j'f1 = this
            show ?thesis
            proof (cases "j' = Suc i")
              case True note j' = this
              hence pj: "?perm j' = j" by simp
              from j' * have k: "k > Suc i" by auto
              hence pk: "?perm k = k - 1" and kj: "k - 1 < j"
                using * k < Suc j by auto
              from ki[unfolded id' pi pj pk] have eq: "A $$ (k - 1, k - 1) = A $$ (i', i')" .
              from * i' k have le: "i'  i" and lt: "i < k - 1" "k - 1 < j" by auto
              from inv_ev[OF _ lt eq] le have "A $$ (i, i) = A $$ (i', i')" 
                by (cases "i = i'", auto)
              with id show ?thesis unfolding id' pi pj pk by simp
            next
              case False note j'f2 = this
              with j'f1 have "j' > Suc i" by auto
              hence pj: "?perm j' = j' - 1" and pk: "?perm k = k - 1"   
                and kj: "i' < j' - 1" "j' - 1 < k - 1" "k - 1 < j"
                using * i' k < Suc j by auto
              from inv_ev[OF kj] ki
              show ?thesis unfolding id' pi pj pk by simp
            qed
          qed
        next
          case False note i'f1 = this
          show ?thesis
          proof (cases "i' = Suc i")
            case True note i' = this
            with * have gt: "i < k - 1" "k - 1 < j" 
              and perm: "?perm i' = j" "?perm k = k - 1" by auto
            (* contradiction with the fact neq from lookup_ev_Some *)
            from ki[unfolded id' perm] neq[OF gt] have False by auto
            thus ?thesis ..
          next
            case False note i'f2 = this
            with i'f1 have "i' > Suc i" by auto
            with * have gt: "i' - 1 < j' - 1"  "j' - 1 < k - 1" "k - 1 < j"
              and perm: "?perm i' = i' - 1" "?perm j' = j' - 1" "?perm k = k - 1" by auto
            show ?thesis using inv_ev[OF gt] ki
              unfolding id' perm by simp
          qed
        qed
      qed
      (* uppert and diff_ev for ?A are proved together, entry by entry *)
      let ?both = "λ A i j. uppert A i j  diff_ev A i j"
      have "inv_all ?both ?A" 
      proof (intro inv_allI)
        fix ii jj
        assume ii: "ii < n" and jj: "jj < n"
        note id = Aind[OF ii ii] Aind[OF jj jj] Aind[OF ii jj]
        note ut = inv_all_uppertD[OF inv(1)]
        note diff = inv_all_diff_evD[OF inv(2)]
        (* uppert: case analysis on the row index ii relative to Suc i and j *)
        have upper: "uppert ?A ii jj" unfolding uppert_def
        proof
          assume ji: "jj < ii"
          show "?A $$ (ii,jj) = 0" 
          proof (cases "ii < Suc i")
            case True note i = this
            with ji have perm: "?perm ii = ii" "?perm jj = jj" by auto
            show ?thesis unfolding id perm using ut[OF ji ii] .
          next
            case False note if1 = this
            show ?thesis
            proof (cases "ii = Suc i")
              case True note i = this
              with ji ij have perm: "?perm ii = j" "?perm jj = jj" and jj: "jj < j" by auto
              show ?thesis unfolding id perm 
                by (rule ut[OF jj jn])
            next
              case False 
              with if1 have if1: "ii > Suc i" by auto
              show ?thesis
              proof (cases "ii  j")
                case True note i = this
                with if1 have pi: "?perm ii = ii - 1" by auto
                show ?thesis
                proof (cases "jj = Suc i")
                  case True note j = this
                  hence pj: "?perm jj = j" by simp
                  from i ji if1 ii j have ij: "ii - 1 < j" and ii: "i < ii - 1" by auto
                  (* here the zero entry comes from diff_ev of A, using neq *)
                  show ?thesis unfolding id pi pj
                    by (rule diff[OF ij jn neq[OF ii ij]])
                next
                  case False
                  with i ji if1 ii have "?perm jj < ii - 1" "ii - 1 < n" by auto
                  from ut[OF this]
                  show ?thesis unfolding id pi .
                qed
              next
                case False 
                hence i: "ii > j" by auto
                with if1 have pi: "?perm ii = ii" by simp
                from i ji if1 ii have "?perm jj < ii" by auto
                from ut[OF this ii]
                show ?thesis unfolding id pi .
              qed
            qed
          qed
        qed
        (* diff_ev: case analysis on the column index jj relative to Suc i and j *)
        have diff: "diff_ev ?A ii jj" unfolding diff_ev_def
        proof (intro impI)
          assume ij': "ii < jj" and neq: "?A $$ (ii,ii)  ?A $$ (jj,jj)"
          show "?A $$ (ii,jj) = 0" 
          proof (cases "jj < Suc i")
            case True note j = this
            with ij' have perm: "?perm ii = ii" "?perm jj = jj" by auto
            show ?thesis using neq unfolding id perm using diff[OF ij' jj] by simp
          next
            case False note jf1 = this
            show ?thesis
            proof (cases "jj = Suc i")
              case True note j = this
              with ij' ij have perm: "?perm jj = j" "?perm ii = ii" and ii: "ii < j" by auto
              show ?thesis using neq unfolding id perm 
                by (intro diff[OF ii jn])
            next
              case False 
              with jf1 have jf1: "jj > Suc i" by auto
              show ?thesis
              proof (cases "jj  j")
                case True note j = this
                with jf1 have pj: "?perm jj = jj - 1" by auto
                show ?thesis
                proof (cases "ii = Suc i")
                  case True note i = this
                  hence pi: "?perm ii = j" by simp
                  from i ij' jf1 jj j have ij: "jj - 1 < j" by auto
                  (* here the zero entry comes from uppert of A *)
                  show ?thesis unfolding id pi pj
                    by (rule ut[OF ij jn])
                next
                  case False
                  with j ij' jf1 jj have "?perm ii < jj - 1" "jj - 1 < n" by auto
                  from diff[OF this] neq
                  show ?thesis unfolding id pj .
                qed
              next
                case False 
                hence j: "jj > j" by auto
                with jf1 have pj: "?perm jj = jj" by simp
                from j ij' jf1 jj have "?perm ii < jj" by auto
                from diff[OF this jj] neq
                show ?thesis unfolding id pj .
              qed
            qed
          qed
        qed
        from upper diff
        show "?both ?A ii jj" ..
      qed
      (* split the combined invariant into its two components *)
      hence "inv_all diff_ev ?A" "inv_all uppert ?A"
        unfolding inv_all_def by blast+
      with inv_ev show ?thesis unfolding B by auto
    qed
    (* conclude with the induction hypothesis for B and Suc j *)
    with IH[OF refl B Bn]
    show ?thesis unfolding id by auto
  qed
qed


(* For i < j < n, A in carrier_mat n n and inv_upto uppert A j, the matrices
   A and add_col_sub_row v i j A are related by same_upto j. Proved via
   same_uptoI and the entry formula add_col_sub_index_row. *)
private lemma add_col_sub_row_same_upto: assumes "i < j" "j < n" "A  carrier_mat n n" "inv_upto uppert A j"
  shows "same_upto j A (add_col_sub_row v i j A)"
  by (intro same_uptoI, subst add_col_sub_index_row, insert assms, auto simp: uppert_def inv_upto_def)

(* inv_from uppert at bound j is preserved by add_col_sub_row v i j, for A in
   carrier_mat n n with i < n, i < j and j < n. Each relevant entry of the
   transformed matrix is rewritten with add_col_sub_index_row (fact id2). *)
private lemma add_col_sub_row_inv_from_uppert: assumes *: "inv_from uppert A j"
  and **: "A  carrier_mat n n" "i < n" "i < j" "j < n" 
  shows "inv_from uppert (add_col_sub_row v i j A) j"
proof -
  (* collect all assumptions under the name * *)
  note * = * **
  let ?A = "add_col_sub_row v i j A"
  show "inv_from uppert ?A j" unfolding inv_from_def
  proof (intro allI impI)
    fix i' j'
    assume **: "i' < n" "j' < n" "j < j'"
    (* dimension side conditions needed to instantiate add_col_sub_index_row *)
    from * ** have "i' < dim_row A" "i' < dim_col A" "j' < dim_row A" "j' < dim_col A" "j < dim_row A" by auto
    note id2 = add_col_sub_index_row[OF this]
    show "uppert ?A i' j'" unfolding uppert_def
    proof (intro conjI impI)
      assume "j' < i'"
      with inv_fromD[OF ‹inv_from uppert A j, unfolded uppert_def, of i' j'] * ** 
      show "?A $$ (i',j') = 0" unfolding id2 using * ** j' < i' by simp
    qed
  qed
qed

(* Invariant lemma for step_3_a: for A in carrier_mat n n with i < j < n,
   inv_upto jb A j, inv_from uppert A j, inv_from_bot for one_zero (column j)
   starting at i, and ev_block n A, the result step_3_a i j A satisfies
   inv_from uppert, inv_upto jb and inv_at one_zero at j, and has the same
   diagonal as A (same_diag).
   Proved by induction along step_3_a.induct; in the step case B is either A
   or add_col_sub_row (- A $$ (i, j)) (Suc i) j A, depending on ?cond. *)
private lemma step_3_a_inv: "A  carrier_mat n n 
   i < j  j < n 
   inv_upto jb A j 
   inv_from uppert A j 
   inv_from_bot (λ A i. one_zero A i j) A i 
   ev_block n A
   inv_from uppert (step_3_a i j A) j 
     inv_upto jb (step_3_a i j A) j 
     inv_at one_zero (step_3_a i j A) j  same_diag A (step_3_a i j A)"
proof (induct i j A rule: step_3_a.induct)
  (* base case of the induction rule *)
  case (1 j A)
  thus ?case by (simp add: inv_from_bot_def inv_at_def)
next
  (* step case: the statement is for Suc i, the induction hypothesis for i and B *)
  case (2 i j A)
  from 2(2-) have A: "A  carrier_mat n n" and ij: "Suc i < j" "i < j" and j: "j < n" by auto
  let ?cond = "A $$ (i, i + 1) = 1  A $$ (i, j)  0"
  let ?B = "add_col_sub_row (- A $$ (i, j)) (Suc i) j A"
  obtain B where B: "B = (if ?cond then ?B else A)" by auto
  from A have Bn: "B  carrier_mat n n" unfolding B by simp
  note IH = 2(1)[OF refl B Bn ij(2) j]
  (* one unfolding step of step_3_a *)
  have id: "step_3_a (Suc i) j A = step_3_a i j B" unfolding B by (simp add: Let_def)
  from ij j have *: "Suc i < n" "j < n" "Suc i  j" by auto
  from 2(2-) have inv: "inv_upto jb A j" "inv_from uppert A j" "ev_block n A"
    "inv_from_bot (λA i. one_zero A i j) A (Suc i)"  by auto
  note evbA = ev_blockD[OF inv(3)]
  show ?case
  proof (cases ?cond)
    case False
    (* no transformation: B = A; extend inv_from_bot from Suc i to i *)
    hence B: "B = A" unfolding B by auto
    have inv2: "inv_from_bot (λA i. one_zero A i j) A i"
      by (rule inv_from_bot_step[OF _ inv(4)],
      insert False ij evbA[of i j] *, auto simp: one_zero_def)
    show ?thesis unfolding id B
      by (rule IH[unfolded B], insert inv inv2, auto)
  next
    case True
    (* transformation applied: B = ?B; re-establish every premise of the
       induction hypothesis for B *)
    hence B: "B = ?B" unfolding B by auto
    let ?C = "step_3_a i j B"
    from inv_uptoD[OF inv(1) j *(1) ij(1), unfolded jb_def] ij 
    have Aji: "A $$ (j, Suc i) = 0" by auto
    have diag: "same_diag A B" unfolding same_diag_def
      by (intro allI impI, insert ij j A Aji B, auto)
    have upto: "same_upto j A B" unfolding B
      by (rule add_col_sub_row_same_upto[OF ‹Suc i < j j < n A inv_upto_mono[OF jb_imp_uppert inv(1)]])
    from add_col_sub_row_inv_from_uppert[OF inv(2) A ‹Suc i < n ‹Suc i < j j < n]
    have from_j: "inv_from uppert B j" unfolding B by blast
    have ev: "A $$ (Suc i, Suc i) = A $$ (j,j)" using evbA[of "Suc i" j] ij j by auto
    have evb_B: "ev_block n B"
      by (rule same_diag_ev_block[OF diag inv(3)])
    note evbB = ev_blockD[OF evb_B]
    {
      (* column j of B: the entry in row i becomes 0, all other rows keep
         the entry of A *)
      fix k
      assume "k < n"
      with A * have k: "k < dim_row A" "k < dim_col A" "j < dim_row A" "j < dim_col A" "j < dim_row A" by auto
      note id = B add_col_sub_index_row[OF k]
      have "B $$ (k,j) = (if k = i then 0 else A $$ (k,j))" unfolding id
        using inv_uptoD[OF inv(1), of k "Suc i", unfolded jb_def]
        by (insert * Aji True ij k < n, auto simp: ev)
    } note id2 = this
    have "inv_from_bot (λA i. one_zero A i j) B i" unfolding inv_from_bot_def
    proof (intro allI impI)
      fix k
      assume "i  k" "k < n"
      thus "one_zero B k j" using inv(4)[unfolded inv_from_bot_def]
        upto[unfolded same_upto_def] evbB[OF k < n j < n]  
        unfolding one_zero_def id2[OF k < n] by auto
    qed
    (* combine the induction hypothesis with transitivity of same_diag *)
    from IH[OF same_upto_inv_upto_jb[OF upto inv(1)] from_j this evb_B]
      same_diag_trans[OF diag]
    show ?thesis unfolding id by blast
  qed
qed

(* Congruence for identify_block: if same_upto k A B and k < n, then
   identify_block yields the same result on A and B for every i < k.
   Proved by induction on i; the step compares the entries at (i, Suc i),
   which are equal by same_uptoD. *)
private lemma identify_block_cong: assumes su: "same_upto k A B" and kn: "k < n"
  shows "i < k  identify_block A i = identify_block B i"
proof (induct i)
  case (Suc i)
  hence "i < k" by auto
  note IH = Suc(1)[OF this]
  (* the condition tested on the entry at (i, Suc i), as a function of the matrix *)
  let ?c = "λ A. A $$ (i,Suc i) = 1"
  from same_uptoD[OF su, of i "Suc i"] kn Suc(2) have 1: "A $$ (i, Suc i) = B $$ (i, Suc i)" by auto
  from 1 have id: "?c A = ?c B" by simp
  show ?case
  proof (cases "?c A")
    case True
    with True[unfolded id] IH show ?thesis by simp
  next
    case False
    with False[unfolded id] show ?thesis by auto
  qed
qed simp

(* Congruence for identify_blocks_main: for k < n and same_upto k A B the
   function returns the same list on A and B (for any accumulator xs).
   Proved by well-founded induction on k (less_induct): one unfolding step
   leads to the index i_b = identify_block A i_e, which is smaller than k. *)
private lemma identify_blocks_main_cong: 
  "k < n  same_upto k A B  identify_blocks_main A k xs = identify_blocks_main B k xs"
proof (induct k arbitrary: xs rule: less_induct)
  case (less k list)
  show ?case
  proof (cases "k = 0")
    case False
    then obtain i_e where k: "k = Suc i_e" by (cases k, auto)
    obtain i_b where idA: "identify_block A i_e = i_b" by force
    from identify_block_le'[OF idA] have ibe: "i_b  i_e" .
    (* identify_block agrees on A and B by identify_block_cong *)
    have idB: "identify_block B i_e = i_b" unfolding idA[symmetric]
      by (rule sym, rule identify_block_cong, insert k less(2-3), auto)
    let ?I = "identify_blocks_main"
    let ?resA = "?I A (Suc i_e) list"  
    let ?recA = "?I A i_b ((i_b, i_e) # list)"
    let ?resB = "?I B (Suc i_e) list"  
    let ?recB = "?I B i_b ((i_b, i_e) # list)"
    (* one unfolding step on both sides *)
    have res: "?resA = ?recA" "?resB = ?recB" using idA idB by auto
    from k ibe have ibk: "i_b < k" by simp
    with less(3) have "same_upto i_b A B" unfolding same_upto_def by auto
    (* induction hypothesis at the smaller index i_b *)
    from less(1)[OF ibk _ this] ibk k < n have "?recA = ?recB" by auto
    thus ?thesis unfolding k res by simp
  qed simp
qed

(* Congruence for identify_blocks: for k < n, same_diag A B and
   same_upto k A B, identify_blocks returns the same list on A and B.
   Reduced to identify_blocks_main_cong after unfolding identify_blocks_def. *)
private lemma identify_blocks_cong: 
  "k < n  same_diag A B  same_upto k A B  identify_blocks A k = identify_blocks B k"
  unfolding identify_blocks_def
  by (intro identify_blocks_main_cong, auto simp: same_diag_def)

(* Variant of inv_from_upto_at_all for matrices with ev_block n A: the
   diff_ev obligations of that lemma are discharged using ev_blockD, so only
   the uppert-related assumptions remain. Conclusion: inv_all uppert A. *)
private lemma inv_from_upto_at_all_ev_block: 
  assumes jb: "inv_upto jb A k" and ut: "inv_from uppert A k" and at: "inv_at p A k" and evb: "ev_block n A"
  and p: " i. i < n  p A i k  uppert A i k"
  and k: "k < n"
  shows "inv_all uppert A"
proof (rule inv_from_upto_at_all[OF jb _ ut at])
  (* first remaining premise: inv_from diff_ev, obtained from ev_blockD *)
  from ev_blockD[OF evb] 
  show "inv_from diff_ev A k" unfolding inv_from_def diff_ev_def by blast
  (* second remaining premise: p yields diff_ev and uppert in column k *)
  fix i
  assume "i < n" "p A i k"
  with ev_blockD[OF evb k, of i] p[OF this] k
  show "diff_ev A i k  uppert A i k"
    unfolding diff_ev_def by auto
qed


text ‹For step 3c, during the inner loop, the invariants are NOT preserved.
  However, at the end of the inner loop, the invariants hold again.
  Therefore, for the inner loop we prove what the resulting matrix looks like in
  each iteration.›
 
(* Closed form for the result of step_3_c_inner_loop.

   Under the invariants inv(1-4), with (lb,l) and (i_begin,i_end) both taken
   from identify_blocks A k and A $$ (l,k) non-zero, running the inner loop
   with factor A $$ (i_end,k) / A $$ (l,k) for Suc i_end - i_begin steps on A
   yields the n x n matrix ?R given in the statement: entry (i_end,k) becomes
   0, entries of the (i_begin,i_end) rows in columns beyond k have the factor
   times A $$ (l + i - i_end, j) subtracted, and every other entry is the
   entry of A.

   Proof outline: ?mat iter diff describes the matrix after iter loop steps
   (diff steps remaining).  Fact mend shows that ?mat 0 (Suc ?idiff) is A, and
   an induction on diff shows that each loop step (add_col_sub_row) maps
   ?mat iter (Suc diff) to ?mat (Suc iter) diff, entry by entry.

   NOTE(review): several operator symbols (meta-implication, comparison,
   membership, conjunction/disjunction, inequality and some cartouche
   delimiters) appear to be missing from the text of this lemma, visible as
   doubled spaces; confirm against the upstream source before relying on the
   exact form of the side conditions. *)
private lemma step_3_c_inner_result: assumes inv:
  "inv_upto jb A k"
  "inv_from uppert A k"
  "inv_at one_zero A k"
  "ev_block n A"
  and k: "k < n"
  and A: "A  carrier_mat n n"
  and lbl: "(lb,l)  set (identify_blocks A k)"
  and ib_block: "(i_begin,i_end)  set (identify_blocks A k)"
  and il: "i_end  l"
  and large: "l - lb  i_end - i_begin"
  and Alk: "A $$ (l,k)  0"
  shows "step_3_c_inner_loop (A $$ (i_end, k) / A $$ (l,k)) l i_end (Suc i_end - i_begin) A =
    mat n n
     (λ(i, j). if (i, j) = (i_end, k) then 0
               else if i_begin  i  i  i_end  k < j then A $$ (i, j) - A $$ (i_end, k) / A $$ (l,k) * A $$ (l + i - i_end, j)
                    else A $$ (i, j))" (is "?L = ?R")
proof -
  let ?Alk = "A $$ (l,k)"
  let ?Aik = "A $$ (i_end,k)"
  (* quot is the factor passed to the inner loop *)
  define quot where "quot = ?Aik / ?Alk"
  let ?idiff = "i_end - i_begin"
  (* ?m: entries of the intermediate matrix, indexed by the number of
     performed steps (iter) and of remaining steps (diff) *)
  let ?m = "λ iter diff i j. if (i,j) = (i_end,k) then if diff = (Suc ?idiff) then ?Aik else 0
    else if i  i_begin + diff  i  i_end  k < j then A $$ (i, j) - quot * A $$ (l + i - i_end, j)
    else if (i,j) = (i_end - iter, Suc l - iter)  iter  {0, Suc ?idiff} then quot 
    else A $$ (i,j)"
  (* ?mm: the form in which the entries of ?mat (Suc iter) diff are used
     below (see fact D inside the induction step) *)
  let ?mm = "λ iter diff i j. if (i,j) = (i_end,k) then 0
    else if i  i_begin + diff  i  i_end  k < j then A $$ (i, j) - quot * A $$ (l + i - i_end, j)
    else if (i,j) = (i_end - Suc iter, l - iter)  iter  ?idiff then quot 
    else A $$ (i,j)"
  let ?mat = "λ iter diff. mat n n (λ (i,j). ?m iter diff i j)"
  from identify_blocks[OF ib_block] have ib: "i_begin  i_end" "i_end < k" by auto
  from identify_blocks[OF lbl] have lb: "lb  l" "l < k" by auto
  (* before the first step the intermediate matrix is A itself *)
  have mend: "?mat 0 (Suc ?idiff) = A"
    by (rule eq_matI, insert A ib, auto)
  {
    (* generalized statement for the induction: ll and ii are the current
       loop indices, linked to l and i_end via iter by the assumptions *)
    fix ll ii diff iter 
    assume "diff  0  ii + iter = i_end" "diff  0  ll + iter = l" "diff + iter = Suc ?idiff"
    hence "step_3_c_inner_loop quot ll ii diff (?mat iter diff) = ?R"
    proof (induct diff arbitrary: ii ll iter)
      case 0
      (* no steps remaining: the loop returns its argument, which equals ?R *)
      hence iter: "iter = Suc ?idiff" by auto
      have "step_3_c_inner_loop quot ll ii 0 (?mat iter 0) = ?mat (Suc ?idiff) 0" 
        unfolding iter step_3_c_inner_loop.simps ..
      also have " = ?R"
        by (rule eq_matI, insert ib, auto simp: quot_def)
      finally show ?case .
    next
      case (Suc diff ii ll)
      note prems = Suc(2-)
      let ?B = "?mat iter (Suc diff)"
      (* one unfolding of the loop: a single add_col_sub_row step followed by
         the recursive call on decremented indices *)
      have "step_3_c_inner_loop quot ll ii (Suc diff) ?B
      = step_3_c_inner_loop quot (ll - 1) (ii - 1) diff (add_col_sub_row quot ii ll ?B)"
        by simp
      (* main work: the single step turns ?mat iter (Suc diff) into
         ?mat (Suc iter) diff; shown entry-wise *)
      also have "add_col_sub_row quot ii ll ?B
        = ?mat (Suc iter) diff" (is "?C = ?D") 
      proof (rule eq_matI, unfold dim_row_mat dim_col_mat)
        fix i j
        assume i: "i < n" and j: "j < n"
        have ll: "ll < n" using prems lb k by auto
        from prems ib k have ii: "ii  i_begin" "ii < n" "ii < k" "ii  i_end"
          and eqs: "ii + iter = i_end" "ll + iter = l" "Suc diff + iter = Suc ?idiff" by auto
        from eqs have diff: "diff < Suc ?idiff" by auto
        from eqs lb k < n have "ll < k" "l < n" by auto
        (* index collects the arithmetic facts handed to the automation below *)
        note index = ib lb k i j ll il large ii this
        let ?Aij = "A $$ (i,j)"
        have D: "?D $$ (i,j) = ?mm iter diff i j" using diff i j by (auto split: if_splits)
        define B where "B = ?B"
        have BB: " i j. i < n  j < n  B $$ (i,j) = ?m iter (Suc diff) i j" unfolding B_def by auto
        have B: "B $$ (i,j) = ?m iter (Suc diff) i j" by (rule BB[OF i j])
        (* explicit entries of add_col_sub_row, via add_col_sub_index_row *)
        have C: "?C $$ (i, j) =  
         (if i = ii  j = ll then B $$ (i, j) + quot * B $$ (i, i) - quot * quot * B $$ (j, i) - quot * B $$ (j, j)
          else if i = ii  j  ll then B $$ (i, j) - quot * B $$ (ll, j) 
          else if i  ii  j = ll then B $$ (i, j) + quot * B $$ (i, ii) 
          else B $$ (i, j))" unfolding B_def
          by (rule add_col_sub_index_row(1), insert i j ll, auto)
        from inv_from_upto_at_all_ev_block[OF inv(1-4) _ k < n] 
        have invA: "inv_all uppert A"
          unfolding one_zero_def uppert_def by auto
        (* destruction rules for the invariants, used throughout the case analysis *)
        note ut = inv_all_uppertD[OF invA]
        note jb = inv_uptoD[OF inv(1), unfolded jb_def]
        note oz = inv_atD[OF inv(3), unfolded one_zero_def]
        note evb = ev_blockD[OF inv(4)]
        note iblock = identify_blocksD[OF k < n]
        note ibe = iblock[OF ib_block]
        let ?ev = "λ i. A $$ (i,i)"

        {
          (* A_ik: for a row i of an identified block that lies strictly
             before the block end, the entry in column k is 0 *)
          fix i ib ie
          assume "(ib,ie)  set (identify_blocks A k)" and i: "ib  i" "i < ie"
          note ibe = iblock[OF this(1)]
          from ibe(3)[OF i] have id: "A $$ (i, Suc i) = 1" by auto
          from i ibe k < n have "i < n" "Suc i < k" by auto
          with oz[OF this(1)] id
          have "A $$ (i,k) = 0" by auto
        } note A_ik = this

        {
          (* rows for which the negated range condition on i_begin / i_end
             holds: column ii of A is 0 there and B coincides with A *)
          fix i 
          assume i: "i < n" and "¬ (i  i_begin  i  i_end)"
          hence choice: "i > i_end  i < i_begin" by auto
          note index = index i 
          from index eqs choice have "i  ii" by auto
          {
            assume 0: "A $$ (i,ii)  0"
            from 0 ut[of ii, OF _ i] i  ii have "i < ii" by force
            from choice index eqs this have "i < i_begin" by auto
            with index have "i < k" by auto
            from jb[OF i ii < n ii < k] 0 i  ii 
            have *: "Suc i = ii" "A $$ (i,ii) = 1" "?ev i = ?ev ii" by auto
            with index i < i_begin have "ii = i_begin" by auto
            with evb[OF i < n k < n] ibe(5) * have False by auto
          }
          hence Aii: "A $$ (i,ii) = 0" by auto
          {
            fix j assume j: "j < n"
            have B: "B $$ (i,j) = ?m iter (Suc diff) i j" using i j unfolding B_def by simp
            from choice have id: "((i, j) = (i_end - iter, Suc l - iter)  iter  {0, Suc ?idiff}) = False" 
              using ib index eqs by auto
            have "B $$ (i,j) = A $$ (i,j)" unfolding B id using choice ib index by auto
          }
          note Aii this
        }                     
        hence A_outside_ii: " i. i < n  ¬ (i_begin  i  i  i_end)  A $$ (i, ii) = 0" 
        and B_outside: " i j. i < n  j < n  ¬ (i_begin  i  i  i_end)  B $$ (i, j) = A $$ (i, j)" by auto

        from diff eqs have iter: "iter  Suc ?idiff" by auto
        {
          (* block_bounds: for two identified blocks with ie < je, the first
             block ends before the second one begins (ie < jb) *)
          fix ib ie jb je
          assume i: "(ib,ie)  set (identify_blocks A k)" and 
            j: "(jb,je)  set (identify_blocks A k)" and lt: "ie < je"
          note i = iblock[OF i]
          note j = iblock[OF j]
          from i j lt have "Suc ie < k" by auto
          with i have Aie: "A $$ (ie, Suc ie)  1" by auto
          have "ie < jb" 
          proof (rule ccontr)
            assume "¬ ie < jb"
            hence "ie  jb" by auto
            from j(3)[OF this lt] Aie show False by auto
          qed
        } note block_bounds = this
        {
          assume "i_end < l"
          from block_bounds[OF ib_block lbl this]
          have "i_end < lb" .
        } note i_less_l = this 
        {
          assume "l < i_end"
          from block_bounds[OF lbl ib_block this]
          have "l < i_begin" .
        } note l_less_i = this
        {
          (* if the two index expressions coincide, then i_begin = i_end *)
          assume "i_end - iter = Suc l - iter"
          with iter large eqs have "i_end = Suc l" by auto
          with l_less_i have "l < i_begin" by auto
          with index i_end = Suc l have "i_begin = i_end" by auto
        } note block = this
        (* the entry of A at (l, i_end) vanishes *)
        have Alie: "A $$ (l, i_end) = 0" 
        proof (cases "l < i_end")
          case True
          {
            assume nz: "A $$ (l, i_end)  0"
            from l_less_i[OF True] index have "0 < i_begin" "l < i_begin" "i_end < n" "i_end < k" by auto
            from jb[OF l < n this(3-4)] il nz
            have "i_end = Suc l" "A $$ (l, Suc l) = 1" by auto
            with iblock[OF lbl] have "k = Suc l" by auto
            with i_end = Suc l i_end < k have False by auto
          }           
          thus ?thesis by auto
        next
          case False
          with il have "i_end < l" by auto
          from ut[OF this l < n] show ?thesis .
        qed          
        (* entry-wise comparison; first split on the position of row i
           relative to i_begin and i_end *)
        show "?C $$ (i,j) = ?D $$ (i,j)" 
        proof (cases "i  i_begin  i  i_end")
          case False
          (* here both ?C and ?D agree with A at (i,j) *)
          hence choice: "i > i_end  i < i_begin" by auto
          from choice have id: "((i, j) = (i_end - Suc iter, l - iter)  iter  ?idiff) = False" 
            using ib index eqs by auto
          have D: "?D $$ (i,j) = ?Aij" unfolding D id using choice ib index by auto
          have B: "B $$ (i,j) = ?Aij" unfolding B_outside[OF i j False] ..
          from index eqs False have "i  ii" by auto
          have Bii: "B $$ (i, ii) = A $$ (i,ii)" unfolding B_outside[OF i ii < n False] ..
          hence C: "?C $$ (i,j) = ?Aij" unfolding C B Bii using i  ii A_outside_ii[OF i False] by auto
          show ?thesis unfolding D C ..
        next
          case True
          with index have "i_begin  i" "i  i_end" "i < k" by auto
          note index = index this
          (* split on the column: j > k, j = k, or j < k *)
          show ?thesis
          proof (cases "j > k")
            case True 
            note index = index this
            have D: "?D $$ (i,j) = (if i_begin + diff  i then ?Aij - quot * A $$ (l + i - i_end, j) else ?Aij)" unfolding D
              using index by auto
            have B: "B $$ (i,j) = (if i_begin + Suc diff  i then ?Aij - quot * A $$ (l + i - i_end, j) else ?Aij)" unfolding B
              using index by auto
            from index eqs have "j > ll" by auto
            hence C: "?C $$ (i,j) = (if i = ii then B $$ (i, j) - quot * B $$ (ll, j) else B $$ (i, j))" unfolding C
              using index by auto
            show ?thesis
            proof (cases "i_begin + Suc diff  i  ¬ (i_begin + diff  i)")
              case True
              from True eqs index have "i  ii" by auto
              from True have "?D $$ (i,j) = B $$ (i,j)" unfolding D B by auto
              also have "B $$ (i,j) = ?C $$ (i,j)" unfolding C using i  ii by auto
              finally show ?thesis ..
            next
              case False
              (* the row that is modified in this very step: i = i_begin + diff = ii *)
              hence i: "i = i_begin + diff" by simp
              with eqs index have ii: "ii = i" by auto
              from index eqs i ii have ll: "ll = l + i - i_end" by auto              
              have not: "¬ (i_begin + Suc diff  ll  ll  i_end)" 
              proof
                from eqs have "ll  l" by auto
                assume "i_begin + Suc diff  ll  ll  i_end"
                hence "i_begin < ll" "ll  i_end" by auto
                with ll  l have "i_begin < l" by auto
                with l_less_i have "¬ l < i_end" by auto
                hence "l  i_end" by simp
                with il i_less_l have "i_end < lb" by auto
                from index large eqs have "lb  ll" by auto
                with i_end < lb have "i_end < ll" by auto
                with ll  i_end 
                show False by auto
              qed
              have D: "?D $$ (i,j) = ?Aij - quot * A $$ (ll, j)" unfolding D unfolding i ll by simp
              have C: "?C $$ (i,j) = ?Aij - quot * B $$ (ll, j)" unfolding C B unfolding ii i by simp
              have B: "B $$ (ll, j) = A $$ (ll, j)" unfolding BB[OF ll < n j] using index not by auto
              show ?thesis unfolding C D B unfolding ii i by (simp split: if_splits)
            qed
          next 
            case False
            hence "j < k  j = k" by auto
            thus ?thesis
            proof
              (* column k: both sides are shown to be 0 *)
              assume jk: "j = k"
              hence "j  Suc l - Suc iter" using index by auto
              hence "?D $$ (i,j) = (if i = i_end then 0 else ?Aij)" unfolding D using jk by auto
              also have " = 0" using A_ik[OF ib_block i_begin  i] i  i_end  unfolding jk by auto
              finally have D: "?D $$ (i,j) = 0" .
              from jk index have "j  ll" by auto
              hence C: "?C $$ (i,j) = (if i = ii then B $$ (i, j) - quot * B $$ (ll, j) else B $$ (i, j))" 
                unfolding C unfolding jk by simp
              have C: "?C $$ (i,j) = 0"
              proof (cases "i = i_end")
                case False
                with index ii jk have i: "i_begin  i" "i < i_end" by auto
                from A_ik[OF ib_block this] have Aij: "A $$ (i,j) = 0" unfolding jk .
                from index i jk have "¬ ((i, j) = (i_end - iter, Suc l - iter)  iter  {0, Suc ?idiff})" by auto
                hence Bij: "B $$ (i,j) = 0" 
                  unfolding B Aij using i jk by auto
                hence C: "?C $$ (i,j) = (if i = ii then - quot * B $$ (ll,j) else 0)" unfolding C by auto
                let ?l = "l - iter"
                from index eqs have ll: "ll = ?l" by auto
                show "?C $$ (i,j) = 0"
                proof (cases "i = ii")
                  case True
                  with index eqs i have l: "lb  ?l" "?l < l" and diff: "Suc diff  Suc ?idiff" by auto
                  from A_ik[OF lbl l] have Alj: "A $$ (ll,j) = 0" unfolding jk ll .
                  from index l jk eqs have "¬ ((ll, j) = (i_end - iter, Suc l - iter)  iter  {0, Suc ?idiff})" by auto
                  hence Bij: "B $$ (ll,j) = 0" unfolding BB[OF ll < n j] Alj
                    using l jk diff by auto
                  thus ?thesis unfolding C by simp
                next
                  case False
                  thus ?thesis unfolding C by simp
                qed
              next
                case True note i = this
                hence Bij: "B $$ (i,j) = (if diff = ?idiff then A $$ (i_end, k) else 0)" unfolding B unfolding jk by auto
                show ?thesis
                proof (cases "i = ii")
                  case True
                  (* first step of the loop: entry (i_end,k) is cancelled
                     because quot * A $$ (l,k) equals A $$ (i_end,k) *)
                  with i eqs have "diff = ?idiff" "ll = l" "iter = 0" by auto
                  hence B: "B $$ (i,j) = A $$ (i_end,k)" unfolding Bij by auto
                  have C: "?C $$ (i,j) = A $$ (i_end,k) - quot * B $$ (l, k)" 
                    unfolding C B unfolding True ll =l jk by simp
                  also have "B $$ (l,k) = A $$ (l,k)"
                    unfolding BB[OF l < n k < n] using il iter = 0 by auto
                  also have "A $$ (i_end,k) - quot *  = 0" unfolding quot_def using Alk by auto
                  finally show ?thesis .
                next
                  case False
                  with i eqs have "diff  ?idiff" by auto
                  thus ?thesis unfolding C Bij using False by auto
                qed
              qed
              show ?thesis unfolding C D ..
            next
              (* columns before k *)
              assume jk: "j < k"
              from eqs il have "ii  ll" by auto
              show ?thesis
              proof (cases "diff = 0  (i,j)  (ii - 1,ll)")
                case False
                (* the position (ii - 1, ll): both sides equal quot *)
                with eqs have **: "i = i_end - Suc iter" "j = l - iter" "iter  ?idiff" 
                  and *: "diff  0" "i = ii - 1" "j = ll" "ii  0" "i  ii" by auto
                hence D: "?D $$ (i,j) = quot" unfolding D using jk index by auto
                from * index eqs False jk have i: "ii = Suc i" "i < i_end" by auto
                from iblock(3)[OF ib_block i_begin  i i < i_end] 
                have Ai: "A $$ (i, ii) = 1" unfolding i .
                have "ii < k" "i  i_end - iter" using index * ** eqs
                  by (blast, force)
                hence Bi: "B $$ (i,ii) = 1" unfolding BB[OF i < n ii < n] Ai by auto
                have "B $$ (i,ll) = A $$ (i,ll)" unfolding BB[OF i < n ll < n] 
                  using i  i_end - iter ll < k by auto
                also have "A $$ (i,ll) = 0"
                proof (rule ccontr)
                  assume nz: "A $$ (i,ll)  0"
                  from i eqs il have neq: "Suc i  ll" by auto
                  from jb[OF i < n ll < n ll < k] nz neq 
                  have "i = ll" by auto
                  with i have "ii = Suc ll" by simp
                  hence "i_end - iter = Suc l - iter" using eqs by auto
                  from block[OF this] have "i_begin = i_end" by auto
                  with large ib lb index have "i = ii" by auto
                  with * show False by auto
                qed
                finally have C: "?C $$ (i,j) = quot" unfolding C using * Bi by auto
                show ?thesis unfolding C D ..
              next
                case True
                (* remaining positions with j < k: ?D has the entry of A, and
                   so does ?C *)
                with eqs have "¬ ((i, j) = (i_end - Suc iter, l - iter)  iter  ?idiff)" 
                  and not: "¬ (i = ii - 1  j = ll  iter  ?idiff)" by auto
                from this(1) have D: "?D $$ (i,j) = ?Aij" unfolding D using jk index by auto
                {
                  (* Bdiag: diagonal entries of B are those of A *)
                  fix i
                  assume "i < n"
                  with index have id: "((i,i) = (i_end,k)) = False" "(i_begin + Suc diff  i  i  i_end  k < i) = False" by auto
                  {
                    assume *: "(i, i) = (i_end - iter, Suc l - iter)  iter  {0, Suc ?idiff}"
                    hence "i_end - iter = Suc l - iter" by auto
                    from block[OF this] * index large eqs have False by auto
                  }                    
                  hence "B $$ (i,i) = ?ev i" unfolding BB[OF i < n i < n] id if_False by auto
                } note Bdiag = this
                from eqs have ii: "ii = i_end - iter" "Suc l - iter = Suc ll" by auto
                have B: "B $$ (i,j) = 
                  (if (i, j) = (ii, Suc ll)  iter  0 then quot else A $$ (i, j))" 
                  unfolding B using ii jk iter by auto
                have ll_i: "ll  i_end - iter" using ii  ll eqs by auto
                have "B $$ (ll,ii) = A $$ (ll,ii)" unfolding BB[OF ll < n ii < n]
                  using ii < k ll_i by auto
                also have " = 0"
                proof (rule ccontr)
                  assume nz: "A $$ (ll,ii)  0"
                  with jb[OF ll < n ii < n ii < k] ii  ll have "Suc ll = ii" by auto
                  with eqs have "i_end - iter = Suc l - iter" by auto
                  from block[OF this] index eqs have "iter = 0" by auto
                  with ii have "ll = l" "ii = i_end" by auto
                  with Alie nz show False by auto
                qed
                finally have Bli: "B $$ (ll,ii) = 0" .
                have C: "?C $$ (i,j) = ?Aij" 
                proof (cases "i = j")
                  case True
                  show ?thesis unfolding C unfolding Bdiag[OF j < n] True using ii  ll Bli
                    by auto
                next
                  case False
                  from lb eqs index large have "lb  ll" "ll  l" by auto
                  note C = C[unfolded Bdiag[OF i < n] Bdiag[OF j < n]]
                  show ?thesis 
                  proof (cases "(i, j) = (ii, Suc ll)  iter  0")
                    case True
                    (* the position where B holds quot from the previous step;
                       it is cancelled by the current step *)
                    hence *: "i = ii" "j = Suc ll" "iter  0" by auto
                    from * eqs index have "ll < l" "Suc ll < k" "Suc ll < n" by auto
                    have B: "B $$ (i,j) = quot" unfolding B using * by auto
                    have "¬ ((ll, j) = (i_end - iter, Suc l - iter)  iter  {0, Suc ?idiff})"
                      using * index eqs by auto
                    hence B': "B $$ (ll, j) = A $$ (ll, j)" 
                      unfolding BB[OF ll < n j < n] using jk by auto
                    have "?C $$ (i,j) = quot - quot * A $$ (ll, Suc ll)" unfolding C B using * B' by auto
                    with iblock(3)[OF lbl lb  ll ll < l] have C: "?C $$ (i,j) = 0" by simp
                    {
                      assume "A $$ (ii, Suc ll)  0"
                      with jb[OF ii < n ‹Suc ll < n ‹Suc ll < k] ii  ll
                      have "ii = Suc ll" by auto
                      with eqs have "i_end - iter = Suc l - iter" by auto
                      from block[OF this] iter  0 iter  Suc ?idiff eqs large have False by auto
                    }
                    hence "A $$ (ii,Suc ll) = 0" by auto
                    thus ?thesis unfolding C unfolding * by simp
                  next
                    case False
                    hence B: "B $$ (i,j) = ?Aij" unfolding B by auto
                    from eqs index have "lb  ll" "ll  l" by auto
                    note index = index this ll_i
                    from evb[of ll k] index have evl: "?ev ll = ?ev k" by auto
                    from evb[of i k] index have evi: "?ev i = ?ev k" by auto
                    from not have not: "i  ii - 1  j  ll  iter = ?idiff" by auto
                    from False have not2: "i  ii  j  Suc ll  iter = 0" by auto
                    show ?thesis
                    proof (cases "i = ii")
                      case True
                      (* row ii: the correction term ?diff is shown to be 0 *)
                      let ?diff = "if j = ll then 0 else - quot * A $$ (ll, j)"
                      have Bli: "B $$ (ll, i) = 0" using True Bli by simp
                      have Blj: "B $$ (ll, j) = A $$ (ll,j)" unfolding BB[OF ll < n j < n] 
                        using index jk by auto
                      from True have C: "?C $$ (i,j) = ?Aij + ?diff" 
                        unfolding C B evi using Bli Blj evl by auto
                      also have "?diff = 0" 
                      proof (rule ccontr)
                        assume "?diff  0"
                        hence jl: "j  ll" and Alj: "A $$ (ll,j)  0" by (auto split: if_splits)
                        with jb[OF ll < n j < n jk] have "j = Suc ll" "?ev ll = ?ev j" by auto
                        with not2 True have "iter = 0" by auto
                        with eqs index jk have id: "A $$ (ll, j) = A $$ (l, Suc l)" and 
                          "j = Suc l" "Suc l < k" "ll = l" 
                          unfolding j = Suc ll by auto
                        from iblock[OF lbl] ‹Suc l < k have "A $$ (l, Suc l)  1" by auto
                        from jb[OF l < n j < n jk] Alj this show False unfolding j = Suc l ll = l by auto
                      qed
                      finally show ?thesis by simp
                    next
                      case False
                      (* other rows: only column ll could change, and the
                         correction term ?diff is again shown to be 0 *)
                      let ?diff = "if j = ll then quot * B $$ (i, ii) else 0"
                      from False have C: "?C $$ (i,j) = ?Aij + ?diff"
                        unfolding C B by auto
                      also have "?diff = 0" 
                      proof (rule ccontr)
                        assume "?diff  0"
                        hence j: "j = ll" and Bi: "B $$ (i, ii)  0" by (auto split: if_splits)
                        from eqs have ii: "i_end - iter = ii" by auto
                        have Bii: "B $$ (i,ii) = A $$ (i, ii)"
                          unfolding BB[OF i < n ii < n] using ii < k iter ii False by auto
                        from Bi Bii have Ai: "A $$ (i,ii)  0" by auto
                        from jb[OF i < n ii < n ii < k] False Ai have ii: "ii = Suc i" 
                          and Ai: "A $$ (i,ii) = 1" by auto
                        from not ii j have iter: "iter = ?idiff" by auto
                        with eqs index have "ii = i_begin" by auto
                        with ii i  i_begin 
                        show False by auto
                      qed
                      finally show ?thesis by simp
                    qed
                  qed
                qed    
                show ?thesis unfolding D C ..
              qed
            qed
          qed
        qed
      qed auto
      (* the remaining diff steps are covered by the induction hypothesis *)
      also have "step_3_c_inner_loop quot (ll - 1) (ii - 1) diff  = ?R"
        by (rule Suc(1), insert prems large, auto)
      finally show ?case .
    qed
  }
  (* instantiate the generalized statement at the start of the loop *)
  note main = this[of "Suc ?idiff" i_end 0 l]
  from ib have suc: "Suc i_end - i_begin = Suc ?idiff" by simp
  have "step_3_c_inner_loop (A $$ (i_end, k) / A $$ (l, k)) l i_end (Suc ?idiff) A
    = step_3_c_inner_loop quot l i_end (Suc ?idiff) (?mat 0 (Suc ?idiff))"
    unfolding mend unfolding quot_def ..
  also have " = ?R" by (rule main, auto)
  finally show ?thesis unfolding suc .
qed

(* Invariant preservation for step_3_c, by induction on the block list bs.

   Premises, in order:
     - A is an n x n matrix and k < n;
     - (lb,l) is one of the blocks returned by identify_blocks A k;
     - the invariants inv_upto jb, inv_from uppert, inv_at one_zero (all at k)
       and ev_block n hold for A;
     - every block in bs is a block of identify_blocks A k;
     - every row be < n that is neither a block end listed in bs nor one of
       l, k has a zero entry in column k;
     - no block in bs is longer than the block (lb,l);
     - x is the non-zero entry A $$ (l,k).
   Conclusion: the result of step_3_c x l k bs A satisfies inv_all uppert,
   is related to A by same_diag and by same_upto k, and satisfies
   inv_at (single_non_zero l k x) at k. *)
private lemma step_3_c_inv: "A ∈ carrier_mat n n 
  ⟹ k < n 
  ⟹ (lb,l) ∈ set (identify_blocks A k)
  ⟹ inv_upto jb A k
  ⟹ inv_from uppert A k
  ⟹ inv_at one_zero A k
  ⟹ ev_block n A
  ⟹ set bs ⊆ set (identify_blocks A k)
  ⟹ (⋀ be. be ∉ snd ` set bs ⟹ be ∉ {l,k} ⟹ be < n ⟹ A $$ (be,k) = 0)
  ⟹ (⋀ bb be. (bb,be) ∈ set bs ⟹ be - bb ≤ l - lb) ― ‹largest block›
  ⟹ x = A $$ (l,k)
  ⟹ x ≠ 0
  ⟹ inv_all uppert (step_3_c x l k bs A)
    ∧ same_diag A (step_3_c x l k bs A) 
    ∧ same_upto k A (step_3_c x l k bs A)
    ∧ inv_at (single_non_zero l k x) (step_3_c x l k bs A) k"
proof (induct bs arbitrary: A) 
  case (Nil A)
  (* no blocks to process: the claim is established for A itself *)
  note inv = Nil(4-7)
  from inv_from_upto_at_all_ev_block[OF inv(1-4) _ ‹k < n›]
  have "inv_all uppert A" unfolding one_zero_def diff_ev_def uppert_def by auto
  moreover 
  have "inv_at (single_non_zero l k x) A k" unfolding single_non_zero_def inv_at_def
    by (intro allI impI conjI, insert Nil, auto)
  ultimately show ?case by auto
next
  case (Cons p bs A)
  obtain i_begin i_end where p: "p = (i_begin, i_end)" by force
  (* give names to the individual premises of the induction step *)
  note Cons = Cons[unfolded p]
  note IH = Cons(1)
  note A = Cons(2)
  note kn = Cons(3)
  note lbl = Cons(4)
  note inv = Cons(5-8)
  note blocks = Cons(9-11)
  note x = Cons(12-13)
  from identify_blocks[OF lbl] kn have lk: "l < k" and ln: "l < n" and "lb ≤ l" by auto
  (* B is the matrix after running the inner loop for the block (i_begin,i_end) *)
  define B where "B = step_3_c_inner_loop (A $$ (i_end,k) / x) l i_end (Suc i_end - i_begin) A"
  show ?case 
  proof (cases "i_end = l")
    case True
    (* the head block ends at l: step_3_c skips it, so the induction
       hypothesis applies to the same matrix A *)
    hence id: "step_3_c x l k (p # bs) A = step_3_c x l k bs A" unfolding p by simp
    show ?thesis unfolding id
      by (rule IH[OF A kn lbl inv _ blocks(2-3) x], insert blocks(1), auto simp: p True)
  next
    case False note il = this
    hence id: "step_3_c x l k (p # bs) A = step_3_c x l k bs B" unfolding B_def p by simp
    from blocks[unfolded p] have 
      ib_block: "(i_begin,i_end) ∈ set (identify_blocks A k)" and large: "i_end - i_begin ≤ l - lb" by auto
    from identify_blocks[OF this(1)] 
    have ibe: "i_begin ≤ i_end" "i_end < k" by auto
    (* explicit form of B, obtained from step_3_c_inner_result *)
    have B: "B = mat n n (λ (i,j). if (i,j) = (i_end,k) then 0 else 
      if i_begin ≤ i ∧ i ≤ i_end ∧ j > k then A $$ (i,j) - A $$ (i_end,k) / x * A $$ (l + i - i_end,j) else A $$ (i,j))"
      unfolding B_def x
      by (rule step_3_c_inner_result[OF inv kn A lbl ib_block il large], insert x, auto)
    (* transfer every premise of the induction hypothesis from A to B *)
    have Bn: "B ∈ carrier_mat n n" unfolding B by auto
    have sdAB: "same_diag A B" unfolding B same_diag_def using ibe by auto
    have suAB: "same_upto k A B" using A kn unfolding B same_upto_def by auto
    have inv_ev: "ev_block n B" using same_diag_ev_block[OF sdAB inv(4)] .
    have inv_jb: "inv_upto jb B k" using same_upto_inv_upto_jb[OF suAB inv(1)] .
    have ib: "identify_blocks A k = identify_blocks B k" using identify_blocks_cong[OF kn sdAB suAB] .
    have inv_ut: "inv_from uppert B k" using inv(2) ibe unfolding B inv_from_def uppert_def by auto
    from x il ibe kn lk have xB: "x = B $$ (l,k)" by (auto simp: B)
    {
      (* column k of B is zero outside the remaining block ends and l, k;
         the row i_end just processed is zero by the explicit form of B *)
      fix be
      assume *: "be ∉ snd ` set bs" "be ∉ {l, k}" "be < n" 
      hence "B $$ (be, k) = 0" using kn blocks(2)[of be] unfolding B
        by (cases "be = i_end", auto)
    } note blocksB = this
    have bs: "set bs ⊆ set (identify_blocks A k)" using blocks(1) by auto
    have inv_oz: "inv_at one_zero B k" using inv(3) ibe kn unfolding B inv_at_def one_zero_def by simp
    (* apply the induction hypothesis to B and compose with A ~ B *)
    show ?thesis unfolding id 
      using IH[OF Bn kn, folded ib, OF lbl inv_jb inv_ut inv_oz inv_ev bs blocksB blocks(3) xB x(2)]
      using same_diag_trans[OF sdAB] same_upto_trans[OF suAB]
      by auto
  qed
qed

lemma step_3_main_inv: "A  carrier_mat n n 
   k > 0
   inv_all uppert A 
   ev_block n A 
   inv_upto jb A k
   inv_all jb (step_3_main n k A)  same_diag A (step_3_main n k A)"
proof (induct k A taking: n rule: step_3_main.induct)
  case (1 k A)
  from 1(2-) have A: "A  carrier_mat n n" and k: "k > 0" and
    inv: "inv_all uppert A" "ev_block n A" "inv_upto jb A k" by auto
  note [simp] = step_3_main.simps[of n k A]
  show ?case
  proof (cases "k  n")
    case True
    thus ?thesis using inv_uptoD[OF inv(3)] 
      by (intro conjI inv_allI, auto)
  next
    case False
    hence kn: "k < n" by simp
    obtain B where B: "B = step_3_a (k - 1) k A" by auto
    note IH = 1(1)[OF False B]
    from A have Bn: "B  carrier_mat n n" unfolding B carrier_mat_def by simp
    from k have "k - 1 < k" by simp   
    {
      fix i
      assume "i < k"
      with ev_blockD[OF inv(2) _ k < n, of i] k < n have "A $$ (i,i) = A $$ (k,k)" by auto
    }
    hence "inv_from_bot (λA i. one_zero A i k) A (k - 1)"
      using inv_all_uppertD[OF inv(1), of k] 
      unfolding inv_from_bot_def one_zero_def by auto
    from step_3_a_inv[OF A k - 1 < k k < n inv(3) inv_all_imp_inv_from[OF inv(1)]
      this inv(2)] same_diag_ev_block[OF  _ inv(2)]
    have inv: "inv_from uppert B k" "ev_block n B" "inv_upto jb B k" 
      "inv_at one_zero B k" and sd: "same_diag A B" unfolding B by auto
    note evb = ev_blockD[OF inv(2)]
    obtain all_blocks where ab: "all_blocks = identify_blocks B k" by simp
    obtain blocks where blocks: "blocks = filter (λ block. B $$ (snd block, k)  0) all_blocks" by simp
    obtain F where F: "F = (if blocks = [] then B
       else let (l_begin,l) = find_largest_block (hd blocks) (tl blocks); x = B $$ (l, k); C = step_3_c x l k blocks B;
            D = mult_col_div_row (inverse x) k C; E = swap_cols_rows_block (Suc l) k D
        in E)" by simp
    note IH = IH[OF ab blocks F]
    have Fn: "F  carrier_mat n n" unfolding F Let_def carrier_mat_def using Bn 
      by (simp split: prod.splits)
    have inv: "inv_all uppert F  same_diag A F  inv_upto jb F (Suc k)" 
    proof (cases "blocks = []")
      case True
      hence F: "F = B" unfolding F by simp
      have lo: "inv_at (lower_one k) B k" 
      proof
        fix i
        assume i: "i < n"
        note lower_one_def[simp]
        show "lower_one k B i k" 
        proof (cases "B $$ (i,k) = 0")
          case False note nz = this
          note oz = inv_atD[OF inv(4) i, unfolded one_zero_def]
          from nz oz have "i  k" by auto
          show ?thesis 
          proof (cases "i = k")
            case False
            with i  k have "i < k" by auto
            with nz oz have ev: "B $$ (i,i) = B $$ (k,k)" unfolding diff_ev_def by auto
            have "(identify_block B i, i)  set all_blocks" unfolding ab
            proof (rule identify_blocks_rev[OF _ k < n]) 
              show "B $$ (i, Suc i) = 0  Suc i < k  Suc i = k"
              proof (cases "Suc i = k")
                case False
                with i < k k < n have "Suc i < k" "Suc i < n" by simp_all
                with nz oz have "B $$ (i, Suc i)  1" by simp
                with inv_uptoD[OF inv(3) i < n ‹Suc i < n ‹Suc i < k, unfolded jb_def]
                have "B $$ (i, Suc i) = 0" by simp
                thus ?thesis using ‹Suc i < k by simp
              qed simp
            qed
            with arg_cong[OF blocks = [][unfolded blocks], of set] have "B $$ (i,k) = 0" by auto
            with nz show ?thesis by auto
          qed auto
        qed auto
      qed
      have inv_jb: "inv_upto jb B (Suc k)"
      proof (rule inv_upto_Suc[OF inv(3)])
        fix i
        assume "i < n"
        from inv_atD[OF lo i < n, unfolded lower_one_def]
        show "jb B i k" unfolding jb_def by auto
      qed
      from inv_from_upto_at_all_ev_block[OF inv(3,1) lo inv(2) _ k < n] lower_one_diff_uppert
      have "inv_all uppert B" by auto
      with inv inv_jb sd
      show ?thesis unfolding F by simp
    next
      case False
      obtain l_start l where l: "find_largest_block (hd blocks) (tl blocks) = (l_start, l)" by force
      obtain x where x: "x = B $$ (l,k)" by simp
      obtain C where C: "C = step_3_c x l k blocks B" by simp
      obtain D where D: "D = mult_col_div_row (inverse x) k C" by auto
      obtain E where E: "E = swap_cols_rows_block (Suc l) k D" by auto
      from find_largest_block[OF False l] have lb: "(l_start,l)  set blocks"
        and llarge: " i_begin i_end. (i_begin,i_end)  set blocks  l - l_start  i_end - i_begin" by auto
      from lb have x0: "x  0" unfolding blocks x by simp
      {
        fix i_start i_end
        assume "(i_start,i_end)  set blocks"
        hence "(i_start,i_end)  set (identify_blocks B k)" unfolding blocks ab by simp
        with identify_blocks[OF this]
        have "i_end < k" "(i_start,i_end)  set (identify_blocks B k)" by auto
      } note block_bound = this
      from block_bound[OF lb]
      have lk: "l < k" and lblock: "(l_start, l)  set (identify_blocks B k)" by auto
      from lk k < n have ln: "l < n" by simp
      from evb[OF l < n k < n]
      have Bll: "B $$ (l,l) = B $$ (k,k)" .
      from False have F: "F = E" unfolding E D C x F l Let_def by simp
      from Bn have Cn: "C  carrier_mat n n" unfolding C carrier_mat_def by simp
      {
        fix be
        assume nmem: "be  snd ` set blocks" and belk: "be  {l, k}" and be: "be < n"
        have "B $$ (be, k) = 0"
        proof (rule ccontr)
          assume nz: "¬ ?thesis"
          note oz = inv_atD[OF inv(4) be, unfolded one_zero_def]
          from belk oz be nz have "be < k" by auto
          obtain bb where ib: "identify_block B be = bb" by force
          note ib_inv = identify_block[OF ib]
          have "B $$ (be, Suc be) = 0  Suc be < k  Suc be = k"
          proof (cases "Suc be = k")
            case False
            with be < k have sbek: "Suc be < k" by auto
            from inv_uptoD[OF inv(3) be < n _ sbek] sbek kn have "jb B be (Suc be)" by auto
            from this[unfolded jb_def] have 01: "B $$ (be, Suc be)  {0,1}" by auto
            from 01 oz sbek nz have "B $$ (be, Suc be) = 0" by auto
            with sbek show ?thesis by auto
          qed auto
          from identify_blocks_rev[OF this kn] 
             nz nmem show False unfolding ab blocks by force 
        qed
      }
      note inv3 = step_3_c_inv[OF Bn k < n lblock inv(3,1,4,2) _ this llarge x x0, of blocks, folded C,
        unfolded ab blocks]
      from inv3 have sdC: "same_diag B C" and suC: "same_upto k B C" by auto
      note sd = same_diag_trans[OF sd sdC]
      from Bll sdC ln k < n 
      have Cll: "C $$ (l,l) = C $$ (k,k)" unfolding same_diag_def by auto
      from same_diag_ev_block[OF sdC inv(2)] same_upto_inv_upto_jb[OF suC inv(3)] inv3 
      have inv: "inv_all uppert C" "ev_block n C"
        "inv_upto jb C k" "inv_at (single_non_zero l k x) C k" by auto
      from x0 have "inverse x  0" by simp
      from Cn have Dn: "D  carrier_mat n n" unfolding D carrier_mat_def by simp
      {
        fix i j
        assume i: "i < n" and j: "j < n"
        with Cn have dC: "i < dim_row C" "i < dim_col C" "j < dim_row C" "j < dim_col C" by auto
        let ?c = "C $$ (i,j)"
        let ?x = "inverse x"
        have "D $$ (i,j) = (if i = l  j = k then 1 else if i = k  j  k then x * ?c else ?c)"
          unfolding D
        proof (subst mult_col_div_index_row[OF dC ‹inverse x  0], unfold inverse_inverse_eq)
          note at = inv_atD[OF inv(4) i < n, unfolded single_non_zero_def]
          show "(if i = k  j  i then x * ?c 
            else if j = k  j  i then ?x * ?c else ?c) =
            (if i = l  j = k then 1 else if i = k  j  k then x * ?c else ?c)" (is "?l = ?r")
          proof (cases "(i,j) = (l,k)")
            case True
            with lk have "?l = ?x * ?c" by auto
            also have " = 1" using at True ‹inverse x  0 by auto
            finally show ?thesis using True by simp
          next
            case False note neq = this
            have "?l = (if i = k  j  k then x * ?c else ?c)"
            proof (cases "i = k  j  k  j = k  i  k")
              case True
              thus ?thesis
              proof
                assume *: "i = k  j  k"
                hence l: "?l = x * ?c" by simp
                show ?thesis using * neq unfolding l by simp
              next
                assume *: "j = k  i  k"
                hence "?l = ?x * ?c" using lk by auto
                from * neq have "i  l" and **: "¬ (i = k  j  k)" by auto
                from at i  l * have "?c = 0" by auto
                with ?l = ?x * ?c ** show ?thesis by auto
              qed
            qed auto
            also have " = ?r" using False by auto
            finally show ?thesis .
          qed
        qed
      } note D = this 
      have sD[simp]: " i. i < n  D $$ (i,i) = C $$ (i,i)" using lk by (auto simp: D)
      from C $$ (l,l) = C $$ (k,k) l < n k < n
      have Dll: "D $$ (l,l) = D $$ (k,k)" by simp      
      have sdD: "same_diag C D" unfolding same_diag_def by simp
      note sd = same_diag_trans[OF sd sdD]
      from same_diag_ev_block[OF sdD inv(2)] have invD: "ev_block n D" .
      note inv = inv_uptoD[OF inv(3), unfolded jb_def] inv_all_uppertD[OF inv(1)] 
        inv_atD[OF inv(4), unfolded single_non_zero_def]
      moreover have "inv_all uppert D"
        by (intro inv_allI, insert inv(2) lk, auto simp: uppert_def D)
      moreover have suD: "same_upto k C D" 
      proof 
        fix i j
        assume i: "i < n" and j: "j < k"
        with kn have jn: "j < n" by simp
        show "C $$ (i, j) = D $$ (i, j)"  
          unfolding D[OF i jn] using j k
           inv(1)[OF i jn j] i j by auto
      qed
      from same_upto_inv_upto_jb[OF suD ‹inv_upto jb C k]
      have "inv_upto jb D k" .
      moreover 
      let ?single_one = "single_one l k"
      have "inv_at ?single_one D k" 
        by (intro inv_atI, insert inv(3) D[OF _ k < n] ln, auto simp: single_one_def)
      ultimately
      have inv: "inv_all uppert D" "ev_block n D"
        "inv_upto jb D k" "inv_at ?single_one D k" using invD by blast+
      note inv = inv_uptoD[OF inv(3), unfolded jb_def] 
        inv_all_uppertD[OF inv(1)] 
        inv_atD[OF inv(4), unfolded single_one_def] 
        ev_blockD[OF inv(2)]
      from suC suD have suD: "same_upto k B D" unfolding same_upto_def by auto
      let ?I = "λ j. if j = Suc l then k else if Suc l < j  j  k then j - 1 else j"
      let ?I' = "λ j. if j = Suc l then k else j - 1"
      {
        fix i j
        assume i: "i < n" and j: "j < n"
        with Dn lk k < n
        have dims: "i < dim_row D" "i < dim_col D" "j < dim_row D" "j < dim_col D" 
          "Suc l  k" "k < dim_row D" "k < dim_col D" by auto
        have "E $$ (i,j) = D $$ (?I i, ?I j)" 
          unfolding E by (rule subst swap_cols_rows_block_index[OF dims])
      } note E = this
      {
        fix i
        assume i: "i < n"
        from l < k have "l  Suc l" "Suc l  k" by auto
        have "E $$ (i,i) = D $$ (i,i)" unfolding E[OF i i]
          by (rule inv(4), insert i k < n, auto) 
      } note Ed = this
      from Ed have ed: "same_diag D E" unfolding same_diag_def by auto
      note sd = same_diag_trans[OF sd ed]
      have "ev_block n E" using same_diag_ev_block[OF ed ‹ev_block n D] by auto
      moreover have Eut: "inv_all uppert E" 
      proof (intro inv_allI, unfold uppert_def, intro impI)
        fix i j 
        assume i: "i < n" and j: "j < n" and ji: "j < i"
        have "?I i < n" using i k < n by auto
        show "E $$ (i,j) = 0"
        proof (cases "?I j < ?I i")
          case True
          from inv(2)[OF this ?I i < n] show ?thesis unfolding E[OF i j] .
        next
          case False
          have "?I i  ?I j" using ji lk by (auto split: if_splits)
          with False have ij: "?I i < ?I j" by simp
          from ij ji have jl: "j = Suc l" using lk by (auto split: if_splits)
          with ji ij have il: "i > Suc l" "i  k" by (auto split: if_splits)
          from jl il have Eij: "E $$ (i,j) = D $$ (i-1,k)" unfolding E[OF i j] by simp
          have "i - 1 < n" "i - 1  {k, l}" using i il by auto
          with inv(3)[of "i-1"] have D: "D $$ (i-1,k) = 0" by auto
          show ?thesis unfolding Eij D by simp
        qed
      qed
      moreover 
      from same_diag_trans[OF ‹same_diag B C ‹same_diag C D] have "same_diag B D" .
      from identify_blocks_cong[OF k < n this suD] 
      have idb: "identify_blocks B k = identify_blocks D k" .
      have "inv_upto jb E (Suc k)" 
      proof (intro inv_uptoI)
        fix i j
        assume i: "i < n" and j: "j < n" and "j < Suc k"
        hence jk: "j  k" by simp
        show "jb E i j"
        proof (cases "E $$ (i,j) = 0  j = i")
          case True
          thus ?thesis unfolding jb_def by auto
        next
          case False note enz = this 
          from inv(4)[OF i j] have same_ev: "D $$ (i,i) = D $$ (j,j)" .
          note inv2 = inv_all_uppertD[OF Eut _ i, of j]
          from False inv2 have "¬ j < i" by auto
          with False have ji: "j > i" by auto
          have "E $$ (i,j)  {0,1}  (j  Suc i  E $$ (i,j) = 0)"
          proof (cases "j  l")
            case True note jl = this
            with ji lk have il: "i  l" and jk: "j < k" by auto
            have id: "E $$ (i,j) = D $$ (i,j)" unfolding E[OF i j] using jl il by simp
            from inv(1)[OF i j jk] ji
            show ?thesis unfolding id by auto
          next
            case False note jl = this
            show ?thesis
            proof (cases "j = Suc l")
              case True note jl = this
              with ji lk have il: "i  l" "i  k" by auto
              have id: "E $$ (i,j) = D $$ (i,k)" unfolding E[OF i j] using jl il by auto
              from inv(3)[OF i] jl il
              show ?thesis unfolding id by (cases "i = l", auto)
            next
              case False
              with jl jk kn have jl: "j > Suc l" and jk: "j - 1 < k" and jn: "j - 1 < n" by auto
              with jk have id: "?I j = j - 1" by auto
              note jb = inv(1)[OF _ jn jk]
              show ?thesis
              proof (cases "i < Suc l")
                case True note il = this
                with id have id: "E $$ (i,j) = D $$ (i,j - 1)" unfolding E[OF i j] by auto
                show ?thesis 
                proof (cases "i = j - 2")
                  case False
                  thus ?thesis unfolding id using jb[OF i] il jl by auto
                next
                  case True
                  with il jl have *: "j = Suc (Suc l)" "i = l" by auto
                  with id have id: "E $$ (i,j) = D $$ (l,Suc l)" by auto
                  from * jl jk have neq: "Suc l  k" by auto
                  from lblock[unfolded idb] have "(l_start, l)  set (identify_blocks D k)" .
                  from this[unfolded identify_blocks_iff[OF kn]] neq
                  have "D $$ (l, Suc l)  1" by auto
                  with jb[OF i] il jl ji * have "D $$ (l, Suc l) = 0" by auto
                  thus ?thesis unfolding id by simp
                qed
              next
                case False note il = this
                show ?thesis
                proof (cases "i = Suc l")
                  case True 
                  with id have id: "E $$ (i,j) = D $$ (k,j - 1)" unfolding E[OF i j] by auto
                  from inv(2)[OF jk kn] show ?thesis unfolding id by simp
                next
                  case False
                  with il jl ji jk kn have il: "i > Suc l" and ik: "i < k" and i_n: "i - 1 < n" by auto
                  with id have id: "E $$ (i,j) = D $$ (i - 1, j - 1)" unfolding E[OF i j] by auto
                  show ?thesis unfolding id using jb[OF i_n] il jl ji by auto
                qed
              qed
            qed
          qed
          thus "jb E i j" unfolding jb_def Ed[OF i] Ed[OF j] same_ev by auto
        qed
      qed
      ultimately show ?thesis using sd unfolding F by simp
    qed
    hence inv: "inv_all uppert F" "ev_block n F" "inv_upto jb F (Suc k)" 
      and sd: "same_diag A F" using same_diag_ev_block[OF _ ‹ev_block n A] by auto
    have "0 < Suc k" by simp
    note IH = IH[OF Fn this inv(1-3)]
    have id: "step_3_main n k A = step_3_main n (Suc k) F" using kn 
      by (simp add: F Let_def blocks ab B)
    from same_diag_trans[OF sd] IH
    show ?thesis unfolding id by auto
  qed
qed

(* Invariants established by the first two phases of the algorithm.
   For an upper triangular n x n matrix A, the matrix B = step_2 (step_1 A)
   satisfies the invariants uppert and diff_ev for all index pairs (inv_all)
   as well as ev_blocks.
   The proof chains step_1_main_inv and step_2_main_inv. *)
lemma step_1_2_inv: 
  assumes A: "A  carrier_mat n n"
  and upper_t: "upper_triangular A"
  and Bid: "B = step_2 (step_1 A)"
  shows "inv_all uppert B" "inv_all diff_ev B" "ev_blocks B" 
proof -
  (* dimension equation used to rewrite dim_row A after unfolding step_1_def / step_2_def *)
  from A have d: "dim_row A = n" by simp
  (* start invariants for step_1_main: A is upper triangular ... *)
  from upper_triangularD[OF upper_t] have inv: "inv_all uppert A"
    unfolding inv_all_def uppert_def using A by auto
  (* ... and diff_ev holds for the initial part (0, 0) *)
  from upper_t have inv2: "inv_part diff_ev A 0 0"
    unfolding inv_part_def diff_ev_def by auto
  (* start invariant for step_2_main *)
  have inv3: "ev_blocks_part 0 (step_1 A)"
    by (rule ev_blocks_partI, auto)
  have A1: "step_1 A  carrier_mat n n" using A unfolding carrier_mat_def by auto
  from A1 have d1: "dim_row (step_1 A) = n" unfolding carrier_mat_def by simp
  (* phase 1: both invariants hold for step_1 A *)
  have "inv_all uppert (step_1 A)  inv_all diff_ev (step_1 A)" unfolding step_1_def d
    by (rule step_1_main_inv[OF _ A inv inv2], simp)
  hence "inv_all uppert (step_1 A)" and "inv_all diff_ev (step_1 A)" by auto
  (* phase 2: step_2_main preserves them and additionally yields ev_blocks *)
  from step_2_main_inv[OF A1 this inv3]
  show "inv_all uppert B" "inv_all diff_ev B" "ev_blocks B" 
    unfolding step_2_def d d1 Bid by auto
qed

(* inv_all' p A states that the predicate p A i j holds for all indices i and j
   that are smaller than dim_row A, i.e., the index bound is taken from the
   matrix itself. *)
definition inv_all' :: "('a mat  nat  nat  bool)  'a mat  bool" where
  "inv_all' p A   i j. i < dim_row A  j < dim_row A  p A i j"

(* If lookup_other_ev ev k A finds nothing (result None), then every diagonal
   entry A $$ (i,i) with index i < k equals ev.
   Proved by the induction scheme of lookup_other_ev. *)
private lemma lookup_other_ev_None: assumes "lookup_other_ev ev k A = None"
  and "i < k"
  shows "A $$ (i,i) = ev"
  using assms by (induct ev k A rule: lookup_other_ev.induct, auto split: if_splits)
  (insert less_antisym, blast) 

(* Characterisation of a successful lookup: if lookup_other_ev ev k A = Some i,
   then i < k, the statement relates the diagonal entry A $$ (i,i) to ev
   (partition_jb uses it as a disequality), and all diagonal entries at
   positions j with i < j < k are equal to ev.
   Proved by the induction scheme of lookup_other_ev. *)
private lemma lookup_other_ev_Some: assumes "lookup_other_ev ev k A = Some i"
  shows "i < k  A $$ (i,i)  ev  ( j. i < j  j < k  A $$ (j,j) = ev)"
  using assms by (induct ev k A rule: lookup_other_ev.induct, auto split: if_splits)
  (insert less_SucE, blast)


(* Correctness of partition_ev_blocks.
   Assume that the n x n matrix A satisfies the invariants uppert and diff_ev
   for all index pairs as well as ev_blocks, and that
   partition_ev_blocks A [] = bs. Then
   - A is the block diagonal matrix built from bs, and
   - every block B in bs is square, satisfies inv_all' uppert and
     ev_block (dim_col B) B. *)
lemma partition_jb: assumes A: "(A :: 'a mat)  carrier_mat n n"
  and inv: "inv_all uppert A" "inv_all diff_ev A" "ev_blocks A"
  and part: "partition_ev_blocks A [] = bs"
  shows "A = diag_block_mat bs" " B. B  set bs  inv_all' uppert B  ev_block (dim_col B) B  dim_row B = dim_col B"
proof -
  have diag: "diag_block_mat [A] = A" using A by auto
  (* Generalised claim for an arbitrary accumulator cs of blocks which already
     satisfy the block properties; the lemma is the instance cs = Nil. *)
  {
    fix cs 
    assume *: " C. C  set cs  dim_row C = dim_col C  inv_all' uppert C  ev_block (dim_col C) C" "partition_ev_blocks A cs = bs"
    (* switch to the primed invariants whose bound is dim_row of the matrix *)
    from inv have inv: "inv_all' uppert A" "inv_all' diff_ev A" "ev_blocks_part n A"
      unfolding inv_all_def inv_all'_def ev_blocks_def using A by auto
    hence "diag_block_mat (A # cs) = diag_block_mat bs  ( B  set bs. inv_all' uppert B  ev_block (dim_col B) B  dim_row B = dim_col B)"
      using A *
    (* induction on the dimension n; A, cs and bs are generalised.
       Fact numbering: less(2-4) are the invariants of A, less(5) is the
       carrier fact, less(6-7) are the two assumptions on cs. *)
    proof (induct n arbitrary: A cs bs rule: less_induct) 
      case (less n A cs bs)
      from less(5) have A: "A  carrier_mat n n" by auto
      hence dim: "dim_row A = n" "dim_col A = n" by auto
      let ?dim = "sum_list (map dim_col cs)"
      let ?C = "diag_block_mat cs"
      define C where "C = ?C"
      from less(6) have cs: " C. C  set cs  inv_all' uppert C  ev_block (dim_col C) C  dim_row C = dim_col C" by auto
      hence dimcs[simp]: "sum_list (map dim_row cs) = ?dim" by (induct cs, auto)
      from dim_diag_block_mat[of cs, unfolded dimcs] obtain nc where C: "?C  carrier_mat nc nc" unfolding carrier_mat_def by auto
      hence dimC: "dim_row C = nc" "dim_col C = nc" unfolding C_def by auto
      (* one unfolding step of partition_ev_blocks, oriented as an equation for bs *)
      note bs = less(7)[unfolded partition_ev_blocks.simps[of A cs] Let_def dim, symmetric]
      show ?case
      proof (cases "n = 0")
        case True
        (* empty matrix: the result is just the accumulator *)
        hence bs: "bs = cs" unfolding bs by simp
        thus ?thesis using cs A by (auto simp: Let_def True)
      next
        case False
        let ?n1 = "n - 1"
        (* search for a diagonal entry that differs from the last one *)
        let ?look = "lookup_other_ev (A $$ (?n1, ?n1)) ?n1 A"
        show ?thesis
        proof (cases ?look)
          case None
          (* nothing found: all diagonal entries coincide, A itself is one block *)
          from False None have bs: "bs = A # cs" unfolding bs by auto
          have ut: "inv_all' uppert A" using less(2) by auto
          from lookup_other_ev_None[OF None] have " i. i < n  A $$ (i,i) = A $$ (?n1, ?n1)"
            by (case_tac "i = ?n1", auto)
          hence evb: "ev_block n A" unfolding ev_block_def dim by metis
          from cs A ut evb show ?thesis unfolding bs by auto
        next
          case (Some i)
          (* position i found: split A after row and column i into four blocks *)
          let ?si = "Suc i"
          from lookup_other_ev_Some[OF Some] have i: "i < ?n1" and neq: "A $$ (i,i)  A $$ (?n1, ?n1)" 
            and between: "j. i < j  j < ?n1  A $$ (j,j) = A $$ (?n1, ?n1)" by auto
          define m where "m = n - ?si"
          from i False have si: "?si < n" by auto
          from False i have nsi: "n = ?si + m" unfolding m_def by auto
          obtain UL UR LL LR where split: "split_block A ?si ?si = (UL, UR, LL, LR)" by (rule prod_cases4)
          from split_block[OF split dim[unfolded nsi]] 
          have carr: "UL  carrier_mat ?si ?si" "UR  carrier_mat ?si m" "LL  carrier_mat m ?si" "LR  carrier_mat m m"
            and Ablock: "A = four_block_mat UL UR LL LR" by auto          
          hence dimLR: "dim_row LR = m" "dim_col LR = m" and dimUL: "dim_col UL = ?si" "dim_row UL = ?si" by auto
          (* the two element-wise invariants of A in unfolded form *)
          from less(3)[unfolded inv_all'_def diff_ev_def] dim
          have diff: " i j. i < n  j < n  i < j  A $$ (i, i)  A $$ (j, j)  A $$ (i, j) = 0" by auto
          from less(2)[unfolded inv_all'_def uppert_def] dim
          have ut: " i j. i < n  j < n  j < i  A $$ (i, j) = 0" by auto
          (* the upper right block is the zero matrix: this follows from diff,
             using the comparison of diagonal entries of UL and LR below *)
          let ?UR = "0m ?si m"
          have UR: "UR = ?UR"
          proof (rule eq_matI)
            fix ia j
            assume ij: "ia < dim_row (0m (Suc i) m)" "j < dim_col (0m (Suc i) m)"
            let ?j = "?si + j"
            have "UL $$ (ia,ia) = A $$ (ia,ia)" using ij carr unfolding Ablock by auto
            also have "  A $$ (?j, ?j)" 
            proof
              assume eq: "A $$ (ia,ia) = A $$ (?j, ?j)"
              from ij have rel: "ia  i" "i  ?j" "?j < n" using nsi i by auto
              from ev_blocks_part_leD[OF less(4) this eq[symmetric]] eq 
              have eq: "A $$ (i,i) = A $$ (?j,?j)" by auto
              also have " = A $$ (?n1, ?n1)" using between[of ?j] rel by (cases "?j = ?n1", auto)
              finally show False using neq by auto
            qed
            also have "A $$ (?si + j, ?si + j) = LR $$ (j,j)" using ij carr unfolding Ablock by auto
            finally show "UR $$ (ia, j) = 0m (Suc i) m $$ (ia, j)"            
              using diff[of ia "?si + j", unfolded Ablock] ij nsi carr by auto
          qed (insert carr, auto)
          (* the lower left block is the zero matrix: derived from ut *)
          let ?LL = "0m m ?si"
          have LL: "LL = ?LL"
          proof (rule eq_matI)
            fix ia j
            show "ia < dim_row (0m m (Suc i))  j < dim_col (0m m (Suc i))  LL $$ (ia, j) = 0m m (Suc i) $$ (ia, j)"
              using ut[of "?si + ia" j, unfolded Ablock] nsi carr by auto
          qed (insert carr, auto)
          (* the three invariants carry over from A to the upper left block UL *)
          have utUL: "inv_all' uppert UL"unfolding inv_all'_def uppert_def
          proof (intro allI impI)
            fix i j
            show "i < dim_row UL  j < dim_row UL  j < i  UL $$ (i, j) = 0"
              using ut[of i j, unfolded Ablock] using nsi carr by auto
          qed
          have diffUL: "inv_all' diff_ev UL"unfolding inv_all'_def diff_ev_def
          proof (intro allI impI)
            fix i j
            show "i < dim_row UL  j < dim_row UL  i < j  UL $$ (i, i)  UL $$ (j, j)  UL $$ (i, j) = 0"
              using diff[of i j, unfolded Ablock] using nsi carr by auto
          qed
          have evbUL: "ev_blocks_part ?si UL"unfolding ev_blocks_part_def
          proof (intro allI impI)
            fix ia j k
            show "ia < j  j < k  k < Suc i  UL $$ (k, k) = UL $$ (ia, ia)  UL $$ (j, j) = UL $$ (ia, ia)"
              using less(4)[unfolded Ablock ev_blocks_part_def, rule_format, of ia j k] using nsi carr by auto
          qed
          (* the lower right block LR is a finished block: uppert holds and
             all of its diagonal entries coincide *)
          have utLR: "inv_all' uppert LR" unfolding inv_all'_def uppert_def
          proof (intro allI impI)
            fix i j
            show "i < dim_row LR  j < dim_row LR  j < i  LR $$ (i, j) = 0"
              using ut[of "?si + i" "?si + j", unfolded Ablock] nsi carr by auto
          qed
          have evbLR: "ev_block (dim_row LR) LR" unfolding ev_block_def
          proof (intro allI impI)
            fix i j
            show "i < dim_row LR  j < dim_row LR  LR $$ (i, i) = LR $$ (j, j)"
              using between[of "?si + i"] between[of "?si + j"] carr nsi
              unfolding Ablock by auto (metis One_nat_def Suc_lessI diff_Suc_1)
          qed
          (* the algorithm continues with UL and the accumulator LR # cs;
             the induction hypothesis applies since Suc i < n *)
          from False Some split have bs: "partition_ev_blocks UL (LR # cs) = bs" unfolding bs by auto
          have IH: "diag_block_mat (UL # LR # cs) = diag_block_mat bs  (Bset bs. inv_all' uppert B  ev_block (dim_col B) B  dim_row B = dim_col B)"
            by (rule less(1)[OF si utUL diffUL evbUL carr(1) _ bs], insert dimLR evbLR utLR cs, auto)
          (* splitting off LR does not change the represented block diagonal matrix *)
          have "diag_block_mat (A # cs) = diag_block_mat (UL # LR # cs)" 
            unfolding diag_block_mat.simps dim C_def[symmetric] dimC dimLR dimUL Let_def
              index_mat_four_block(2-3) Ablock UR LL
            using assoc_four_block_mat[of UL LR C] dimC carr by simp
          with IH show ?thesis by auto
        qed
      qed
    qed
  }
  (* instantiate the accumulator with the empty list *)
  from this[of Nil, OF _ part] show "A = diag_block_mat bs" " B. B  set bs  inv_all' uppert B  ev_block (dim_col B) B  dim_row B = dim_col B"
    unfolding diag by fastforce+
qed

(* Start invariant for the third phase: for an n x n matrix A that satisfies
   uppert for all index pairs, the invariant jb holds up to bound 1.
   Only j = 0 has to be considered, and the statement jb A i 0 follows from
   uppert at column 0. *)
lemma uppert_to_jb: assumes ut: "inv_all uppert A" and "A  carrier_mat n n"
shows "inv_upto jb A 1"
proof (rule inv_uptoI)
  fix i j
  assume "i < n" "j < n" and "j < 1"
  hence j: "j = 0" and jn: "0 < n" by auto
  show "jb A i j" unfolding jb_def j using inv_all_uppertD[OF ut _ i < n, of 0]
    by auto
qed

(* Correctness of jnf_vector.
   Assume that the n x n matrix A satisfies jb for all index pairs below n and
   that all its diagonal entries coincide (ev_block n A). Then
   - jordan_matrix (jnf_vector A) is A itself, and
   - the second conclusion is a statement about 0 and the first components
     (block sizes) of jnf_vector A.
   The proof shows a generalised statement for jnf_vector_main k A and every
   k up to n by induction on k. *)
lemma jnf_vector: assumes A: "A  carrier_mat n n"
  and jb: " i j. i < n  j < n  jb A i j"
  and evb: "ev_block n A"
shows "jordan_matrix (jnf_vector A) = (A :: 'a mat)"
  "0  fst ` set (jnf_vector A)" 
proof -
  from A have "dim_row A = n" by simp
  hence id: "jnf_vector A = jnf_vector_main n A" unfolding jnf_vector_def by auto
  (* ?map turns (size, eigenvalue) pairs into Jordan blocks,
     ?B k is the block diagonal matrix obtained from jnf_vector_main k A *)
  let ?map = "map (λ(n, a). jordan_block n (a :: 'a))"
  let ?B = "λ k. diag_block_mat (?map (jnf_vector_main k A))"
  {
    fix k
    assume "k  n"
    (* generalised claim: ?B k agrees with A on all indices below k,
       it is a k x k matrix, and the statement on the block sizes holds *)
    hence "( i j. i < k  j < k  ?B k $$ (i,j) = A $$ (i,j))
       diag_block_mat (?map (jnf_vector_main k A))  carrier_mat k k
       0  fst ` set (jnf_vector_main k A)"
    proof (induct k rule: less_induct)
      case (less sk)
      show ?case
      proof (cases sk)
        case (Suc k)
        (* b is the result of identify_block for position k; the last Jordan
           block has size Suc k - b and eigenvalue A $$ (b,b) *)
        obtain b where ib: "identify_block A k = b" by force
        let ?ev = "A $$ (b,b)"
        from ib have id: "jnf_vector_main sk A = jnf_vector_main b A @ [(Suc k - b, ?ev)]" unfolding Suc by simp
        let ?c = "Suc k - b"
        (* B: matrix for the first b rows (covered by the induction
           hypothesis), C: the newly appended Jordan block *)
        define B where "B = ?B b"
        define C where "C = jordan_block ?c ?ev"
        have C: "C  carrier_mat ?c ?c" unfolding C_def by auto
        let ?FB = "λ Bb Cc. four_block_mat Bb (0m (dim_row Bb) (dim_col Cc)) (0m (dim_row Cc) (dim_col Bb)) Cc"
        from identify_block_le'[OF ib] have bk: "b  k" .
        with Suc less(2) have "b < sk" "b  n" by auto
        note IH = less(1)[OF this, folded B_def]
        have B: "B  carrier_mat b b" using IH by simp
        from bk Suc have sk: "sk = b + ?c" by auto
        show ?thesis unfolding id map_append list.simps diag_block_mat_last split B_def[symmetric] C_def[symmetric] Let_def
        proof (intro allI conjI impI)
          show "?FB B C  carrier_mat sk sk" unfolding sk using four_block_carrier_mat[OF B C] .
          fix i j
          assume i: "i < sk" and j: "j < sk"
          with jb sk  n 
          have jb: "jb A i j" by auto
          have ut: "uppert A i j" by (rule jb_imp_uppert[OF jb])
          have de: "diff_ev A i j" by (rule jb_imp_diff_ev[OF jb])
          from B C have dim: "dim_row B = b" "dim_col B = b" "dim_col C = ?c" "dim_row C = ?c" by auto
          from sk B C i j have "i < dim_row B + dim_row C" "j < dim_col B + dim_col C" by auto
          note id = index_mat_four_block(1)[OF this, unfolded dim]
          (* entry of the four block matrix, by the quadrant of (i, j) *)
          have id: "?FB B C $$ (i,j) = 
          (if i < b then if j < b then B $$ (i, j) else 0 
            else if j < b then 0 else C $$ (i - b, j - b))" 
            unfolding id dim using i j sk by auto
          show "?FB B C $$ (i,j) = A $$ (i,j)" 
          proof (cases "i < b  j < b")
            case True (* upper left *)
            hence "?FB B C $$ (i,j) = B $$ (i,j)" unfolding id by auto
            with IH True show ?thesis by auto
          next
            case False note not_ul = this
            (* properties of the identified block start b *)
            note ib = identify_block[OF ib]
            show ?thesis
            proof (cases "¬ i < b  j < b  i < b  ¬ j < b")
              case True (* not on main diagonal *)
              hence id: "?FB B C $$ (i,j) = 0" unfolding id by auto
              show ?thesis
              proof (cases "j < i") 
                case True (* lower left *)
                with ut show ?thesis unfolding id uppert_def by auto
              next
                case False (* upper right *)
                with True have *: "j  b" "i < b" "j > i" by auto
                have "A $$ (i,j) = 0" 
                proof (rule ccontr)
                  (* by jb, the only candidate is the entry (b - 1, b),
                     which is excluded by the block start property ib(2) *)
                  assume "A $$ (i,j)  0"
                  with jb[unfolded jb_def] *
                  have ji: "j = b" "i = b - 1" "b > 0" and no_border: "A $$ (i, i) = A $$ (j, j)" "A $$ (i,j) = 1" by auto
                  from no_border[unfolded ji] ib(2) b > 0 show False by auto                
                qed
                thus ?thesis unfolding id by simp
              qed
            next
              case False (* lower right *)
              with not_ul have *: "¬ i < b" "¬ j < b" by auto
              hence id: "?FB B C $$ (i,j) = C $$ (i - b, j - b)" unfolding id by auto
              from * i j have ijc: "i - b < ?c" "j - b < ?c" unfolding sk by auto
              (* entries of the Jordan block C: eigenvalue on the diagonal,
                 1 directly above the diagonal, 0 elsewhere *)
              have id: "?FB B C $$ (i,j) = (if i - b = j - b then ?ev else if Suc (i - b) = j - b then 1 else 0)"
                unfolding id unfolding C_def jordan_block_index(1)[OF ijc] ..
              show ?thesis 
              proof (cases "i - b = j - b")
                case True
                (* diagonal entry: all diagonal entries of A coincide by evb *)
                hence id: "?FB B C $$ (i,j) = ?ev" unfolding id by simp
                from True * have ij: "j = i" by auto
                have i_n: "i < n" using i sk  n by auto 
                have b_n: "b < n" using b < sk sk  n by auto
                from ib(3)[of i] True * i j Suc ev_blockD[OF evb i_n b_n] have "A $$ (i,j) = ?ev" unfolding ij by auto
                with id show ?thesis by simp
              next
                case False note neq = this
                show ?thesis
                proof (cases "j - b = Suc (i - b)")
                  case True
                  (* entry directly above the diagonal inside the block: 1 by ib(3) *)
                  hence id: "?FB B C $$ (i,j) = 1" unfolding id by simp
                  from True * have ij: "j = Suc i" by auto
                  from ib(3)[of i] True * i j Suc have "A $$ (i,j) = 1" unfolding ij by auto
                  with id show ?thesis by simp
                next
                  case False
                  (* all remaining entries are 0 by jb *)
                  with neq have id: "?FB B C $$ (i,j) = 0" unfolding id by simp
                  from * neq False have "i  j" "Suc i  j" by auto
                  with jb[unfolded jb_def] have "A $$ (i,j) = 0" by auto
                  with id show ?thesis by simp
                qed
              qed
            qed
          qed
        qed (insert bk IH, auto)
      qed auto
    qed
  }
  (* instantiate the generalised claim with k = n *)
  from this[OF le_refl] A
  show "jordan_matrix (jnf_vector A) = A" "0  fst ` set (jnf_vector A)"
    unfolding id jordan_matrix_def by auto
qed

end


(* Main correctness result of the algorithm: for an upper triangular n x n
   matrix A, triangular_to_jnf_vector A is a Jordan normal form of A.
   Proof outline:
   - step_1 and step_2 yield a matrix ?B that is similar to A and satisfies the
     invariants of step_1_2_inv,
   - partition_jb splits ?B into diagonal blocks Cs,
   - for each block C, step_3 C is similar to C and jnf_vector (step_3 C)
     describes step_3 C, so it is a Jordan normal form of C,
   - jordan_nf_diag_block_mat combines the results of the blocks. *)
lemma triangular_to_jnf_vector: 
  assumes A: "A  carrier_mat n n"
  and upper_t: "upper_triangular A"
  shows "jordan_nf A (triangular_to_jnf_vector A)"
proof -
  let ?B = "step_2 (step_1 A)"
  let ?J = "triangular_to_jnf_vector A"
  have A1: "step_1 A  carrier_mat n n" using A unfolding carrier_mat_def by simp
  (* phases 1 and 2 are similarity transformations *)
  from similar_mat_trans[OF step_2_similar step_1_similar, OF A1 A]
  have sim: "similar_mat ?B A" .
  have B: "?B  carrier_mat n n" using A unfolding carrier_mat_def by auto
  (* split ?B into its diagonal blocks *)
  define Cs where "Cs = partition_ev_blocks ?B []"
  from step_1_2_inv[OF A upper_t refl]
  have inv: "inv_all n uppert ?B" "inv_all n diff_ev ?B" "ev_blocks n ?B" by auto
  from partition_jb[OF B inv, of Cs] have BC: "?B = diag_block_mat Cs"
    and Cs: " C. C  set Cs  inv_all' uppert C  ev_block (dim_col C) C  dim_row C = dim_col C" unfolding Cs_def by auto
  define D where "D = map step_3 Cs"
  let ?D = "diag_block_mat D"
  (* pairs of a block and the vector computed for it *)
  let ?CD = "map (λ C. (C, (jnf_vector o step_3) C)) Cs"
  {
    (* correctness for a single block C; note that n is locally redefined
       as the dimension of C *)
    fix C D
    assume mem: "(C,D)  set ?CD"
    hence DC: "D = jnf_vector (step_3 C)" and C: "C  set Cs" by auto
    let ?D = "step_3 C"
    define n where "n = dim_col C"
    from Cs[OF C] have C: "inv_all n uppert C" "ev_block n C" "C  carrier_mat n n" 
      unfolding inv_all'_def inv_all_def n_def carrier_mat_def by auto
    from step_3_similar[OF C(3)] have sim: "similar_mat C ?D" by (rule similar_mat_sym)
    from similar_matD[OF sim] C have D: "?D  carrier_mat n n" unfolding carrier_mat_def by auto
    from C(3) have dimC: "dim_row C = n" by auto
    (* phase 3 establishes jb everywhere and keeps the diagonal *)
    from step_3_main_inv[OF C(3) _ C(1,2) uppert_to_jb[OF C(1) C(3)]]
    have "inv_all n jb (step_3 C)" and sd: "same_diag n C (step_3 C)" unfolding step_3_def dimC by auto
    hence jbD: " i j. i < n  j < n  jb ?D i j" unfolding inv_all_def DC by auto
    from same_diag_ev_block[OF sd C(2)] have "ev_block n (step_3 C)" by auto
    from jnf_vector[OF D jbD this] have "jordan_matrix D = ?D" "0  fst ` set D" unfolding DC by auto
    with sim have "jordan_nf C D" unfolding jordan_nf_def by simp
  } note jnf_blocks = this
  have id: "map fst ?CD = Cs" by (induct Cs, auto)
  have id2: "map snd ?CD = map (jnf_vector o step_3) Cs" by (induct Cs, auto)
  have J: "?J = concat (map (jnf_vector  step_3) Cs)" unfolding 
    triangular_to_jnf_vector_def Let_def Cs_def ..
  (* combine the blockwise results and transfer them along the similarity to A *)
  from jordan_nf_diag_block_mat[of ?CD, OF jnf_blocks, unfolded id id2]
  have jnf: "jordan_nf (diag_block_mat Cs) ?J" unfolding J .
  hence "similar_mat (diag_block_mat Cs) (jordan_matrix ?J)" 
    unfolding jordan_nf_def by auto
  from similar_mat_sym[OF similar_mat_trans[OF similar_mat_sym[OF this] sim[unfolded BC]]] jnf
  show ?thesis unfolding jordan_nf_def by auto
qed

(* Hide auxiliary definitions and functions: the constants listed below are
   internal to the algorithm and its correctness proof; hide_const removes
   their names from the name space for subsequent user input. *)
hide_const 
  lookup_ev
  find_largest_block
  swap_cols_rows_block
  identify_block
  identify_blocks_main
  identify_blocks
  inv_all inv_all' same_diag
  jb uppert diff_ev ev_blocks ev_block
  step_1_main step_1 
  step_2_main step_2 
  step_3_a step_3_c step_3_c_inner_loop step_3 
  jnf_vector_main


subsection ‹Combination with Schur-decomposition›

(* Computes a vector of (size, value) pairs for the matrix A from the list es:
   first schur_upper_triangular A es is computed, then the result is passed to
   triangular_to_jnf_vector. The lemma of the same name states when the result
   is a Jordan normal form of A. *)
definition jordan_nf_via_factored_charpoly :: "'a :: conjugatable_ordered_field mat  'a list  (nat × 'a)list"
  where "jordan_nf_via_factored_charpoly A es = 
    triangular_to_jnf_vector (schur_upper_triangular A es)"

(* Correctness of jordan_nf_via_factored_charpoly: if the characteristic
   polynomial of the n x n matrix A is given by the linear factors [:- a, 1:]
   for the elements a of es, then the computed vector is a Jordan normal form
   of A. *)
lemma jordan_nf_via_factored_charpoly: assumes A: "A  carrier_mat n n"
  and linear: "char_poly A = ( a  es. [:- a, 1:])"
  shows "jordan_nf A (jordan_nf_via_factored_charpoly A es)"
proof -
  let ?B = "schur_upper_triangular A es"
  let ?J = "jordan_nf_via_factored_charpoly A es"
  (* the Schur decomposition yields an upper triangular matrix similar to A *)
  from schur_upper_triangular[OF A linear]
  have B: "?B  carrier_mat n n" "upper_triangular ?B" and AB: "similar_mat A ?B" by auto
  (* apply the triangular algorithm to it and transfer the result along the similarity *)
  from triangular_to_jnf_vector[OF B] have "jordan_nf ?B ?J" 
    unfolding jordan_nf_via_factored_charpoly_def .
  with similar_mat_trans[OF AB] show "jordan_nf A ?J" unfolding jordan_nf_def by blast
qed


(* Existence of a Jordan normal form: it suffices that the characteristic
   polynomial of A is given by linear factors [:- a, 1:] for the elements of
   some list as. The witness is jordan_nf_via_factored_charpoly A as. *)
lemma jordan_nf_exists: assumes A: "A  carrier_mat n n"
  and linear: "char_poly A = ( (a :: 'a :: conjugatable_ordered_field)  as. [:- a, 1:])"
  shows "n_as. jordan_nf A n_as"
  using jordan_nf_via_factored_charpoly[OF A linear] by blast

(* A square matrix A has a Jordan normal form (?l) if and only if its
   characteristic polynomial can be written via linear factors [:- a, 1:] for
   the elements of some list as (?r). *)
lemma jordan_nf_iff_linear_factorization: fixes A :: "'a :: conjugatable_ordered_field mat"
  assumes A: "A  carrier_mat n n" 
  shows "( n_as. jordan_nf A n_as) = ( as. char_poly A = ( a  as. [:- a, 1:]))"
    (is "?l = ?r")
proof
  (* from a factorization to a Jordan normal form: jordan_nf_exists *)
  assume ?r
  thus ?l using jordan_nf_exists[OF A] by auto
next
  (* from a Jordan normal form to a factorization: rewrite the characteristic
     polynomial with jordan_nf_char_poly and expand_powers *)
  assume ?l
  then obtain n_as where jnf: "jordan_nf A n_as" by auto
  show ?r unfolding jordan_nf_char_poly[OF jnf] expand_powers[of "λ a. [: -a, 1:]" n_as] by blast
qed

(* Two complex n x n matrices are similar if and only if they have the same
   Jordan normal forms, i.e., the predicates jordan_nf A and jordan_nf B
   coincide. *)
lemma similar_iff_same_jordan_nf: fixes A :: "complex mat" 
  assumes A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
  shows "similar_mat A B = (jordan_nf A = jordan_nf B)" 
proof 
  (* first direction: by symmetry and transitivity of similar_mat *)
  show "similar_mat A B  jordan_nf A = jordan_nf B" 
    by (intro ext, auto simp: jordan_nf_def, insert similar_mat_trans similar_mat_sym, blast+)
  (* second direction: obtain some Jordan normal form n_as of A via
     char_poly_factorized and jordan_nf_exists; by assumption it is also a
     Jordan normal form of B, so both matrices are similar to the same
     Jordan matrix *)
  assume id: "jordan_nf A = jordan_nf B" 
  from char_poly_factorized[OF A] obtain as where "char_poly A = (aas. [:- a, 1:])" by auto
  from jordan_nf_exists[OF A this] obtain n_as where jnfA: "jordan_nf A n_as" ..
  with id have jnfB: "jordan_nf B n_as" by simp
  from jnfA jnfB show "similar_mat A B" 
    unfolding jordan_nf_def using similar_mat_trans similar_mat_sym by blast
qed

(* Relates the order of x in the characteristic polynomial of the scalar
   multiple of A by k to the order of x / k in the characteristic polynomial
   of A, for complex n x n matrices.
   NOTE(review): the relation in assumption k is missing in this text;
   presumably k is required to be non-zero (the statement divides by k). *)
lemma order_char_poly_smult: fixes A :: "complex mat" 
  assumes A: "A  carrier_mat n n" 
  and k: "k  0" 
shows "order x (char_poly (k m A)) = order (x / k) (char_poly A)" 
proof -
  (* Obtain some JNF n_as of A via the factorisation of char_poly A. *)
  from char_poly_factorized[OF A] obtain as where "char_poly A = (aas. [:- a, 1:])" by auto
  from jordan_nf_exists[OF A this] obtain n_as where jnf: "jordan_nf A n_as" ..
  (* Express both orders through the JNF of A and the JNF of the scaled matrix
     (jordan_nf_smult), then compare by induction over n_as. *)
  show ?thesis unfolding jordan_nf_order[OF jnf] jordan_nf_order[OF jordan_nf_smult[OF jnf k]]
    by (induct n_as, auto simp: k)
qed

subsection ‹Application for Complexity›

text ‹We can estimate the complexity via the multiplicity of the eigenvalues with norm 1.›
(* Variant of factored_char_poly_norm_bound for conjugatable ordered fields:
   the required existence of a Jordan normal form is discharged by
   jordan_nf_exists, so only the factorisation of char_poly A has to be given.
   Assumptions le_1 and le_N constrain the norm of every element a of as and
   the number of occurrences in as of those a with norm 1 (relation symbols are
   missing in this text; presumably "at most 1" and "at most N").
   The conclusion is a norm_bound on the powers A ^m k of the form
   c1 + c2 * of_nat k ^ (N - 1). *)
lemma factored_char_poly_norm_bound_cof: assumes A: "A  carrier_mat n n"
  and linear_factors: "char_poly A = ( (a :: 'a :: {conjugatable_ordered_field, real_normed_field})  as. [:- a, 1:])"
  and le_1: " a. a  set as  norm a  1"
  and le_N: " a. a  set as  norm a = 1  length (filter ((=) a) as)  N"
  shows " c1 c2.  k. norm_bound (A ^m k) (c1 + c2 * of_nat k ^ (N - 1))"
  by (rule factored_char_poly_norm_bound[OF A linear_factors jordan_nf_exists[OF A linear_factors] le_1 le_N])


text ‹If we have an upper triangular matrix, then the eigenvalues are exactly the entries on the diagonal.
  So then we don't need to explicitly compute the characteristic polynomial.›
(* Growth bound for powers of an upper triangular n x n matrix A, stated via
   its diagonal: le_1 and le_N constrain the norm of every diagonal entry a and
   the number of occurrences of those a with norm 1 in diag_mat A (relation
   symbols are missing in this text; presumably "at most 1" and "at most N").
   The conclusion is a norm_bound on A ^m k of the form
   c1 + c2 * of_nat k ^ (N - 1). *)
lemma counting_ones_complexity: 
  fixes A :: "'a :: real_normed_field mat"
  assumes A: "A  carrier_mat n n"
  and upper_t: "upper_triangular A"
  and le_1: " a. a  set (diag_mat A)  norm a  1"
  and le_N: " a. a  set (diag_mat A)  norm a = 1  length (filter ((=) a) (diag_mat A))  N"
  shows " c1 c2.  k. norm_bound (A ^m k) (c1 + c2 * of_nat k ^ (N - 1))"
proof -
  (* A triangular matrix has a JNF, computed by triangular_to_jnf_vector. *)
  from triangular_to_jnf_vector[OF A upper_t] have jnf: " n_as. jordan_nf A n_as" ..
  (* char_poly_upper_triangular provides the factorisation of char_poly A that
     factored_char_poly_norm_bound requires, in terms of the diagonal. *)
  show ?thesis
    by (rule factored_char_poly_norm_bound[OF A char_poly_upper_triangular[OF A upper_t] jnf le_1 le_N])
qed

text ‹If we have an upper triangular matrix $A$ then we can compute a JNF-vector of it.
  If this vector does not contain entries $(n,ev)$ with $|ev|$ being larger than 1,
  then the growth rate of $A^k$ can be bounded by ${\cal O}(k^{N-1})$ 
  where $N$ is the maximal value for $n$ such that an entry $(n,ev)$ with $|ev| = 1$ occurs in the vector, i.e.,
  the size of the largest Jordan block with an eigenvalue of norm 1.
  This method gives a precise complexity bound.›
(* Growth bound for powers of an upper triangular n x n matrix A, stated via
   the computed vector triangular_to_jnf_vector A: le_1 constrains the norm of
   the second component a of every entry (n, a), and le_N constrains the first
   component n of the entries whose a has norm 1 (relation symbols are missing
   in this text; presumably "at most 1" and "at most N").
   The conclusion is a norm_bound on A ^m k of the form
   c1 + c2 * of_nat k ^ (N - 1). *)
lemma compute_jnf_complexity: 
  assumes A: "A  carrier_mat n n"
  and upper_t: "upper_triangular (A :: 'a :: real_normed_field mat)"
  and le_1: " n a. (n,a)  set (triangular_to_jnf_vector A)  norm a  1"
  and le_N: " n a. (n,a)  set (triangular_to_jnf_vector A)  norm a = 1  n  N"
  shows " c1 c2.  k. norm_bound (A ^m k) (c1 + c2 * of_nat k ^ (N - 1))"
proof -
  let ?jnf = "triangular_to_jnf_vector A"
  (* The computed vector is a Jordan normal form of A. *)
  from triangular_to_jnf_vector[OF A upper_t] have jnf: "jordan_nf A ?jnf" .
  (* The bound then follows from the general result for matrices with a JNF. *)
  show ?thesis
    by (rule jordan_nf_matrix_poly_bound[OF A le_1 le_N jnf])
qed

end

Theory Matrix_Impl

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Code Equations for All Algorithms›

text ‹In this theory we load all executable algorithms, i.e., Gauss-Jordan, determinants,
  Jordan normal form computation, etc., and perform some basic tests.›

theory Matrix_Impl
imports 
  Matrix_IArray_Impl
  Gauss_Jordan_IArray_Impl
  Determinant_Impl
  Show_Matrix
  Jordan_Normal_Form_Existence
  Show.Show_Instances
begin

text ‹For determinants we require class @{class idom_divide}, so integers, rationals, etc. can be used.›
(* The same 4 x 4 matrix (given as a list of rows), once over int and once over rat. *)
value[code] "det (mat_of_rows_list 4 [[1 :: int, 4, 9, -1], [-3, -1, 5, 4], [4, 2, 0,2], [8,-9, 5,7]])"
value[code] "det (mat_of_rows_list 4 [[1 :: rat, 4, 9, -1], [-3, -1, 5, 4], [4, 2, 0,2], [8,-9, 5,7]])"

text ‹Since polynomials require @{class field} elements to be in class @{class idom_divide}, the implementation
  of characteristic polynomials is not applicable for integer matrices, but it is for rational and real matrices.›

(* Characteristic polynomial of the same matrix, now with real entries. *)
value[code] "char_poly (mat_of_rows_list 4 [[1 :: real, 4, 9, -1], [-3, -1, 5, 4], [4, 2, 0,2], [8,-9, 5,7]])"

text ‹Also Jordan normal form computation requires matrices over @{class field} entries.›

(* Input: an upper triangular 6 x 6 rational matrix whose diagonal entries
   are 3, 3, 3, 5, 5, 3. *)
value[code] "triangular_to_jnf_vector (mat_of_rows_list 6 [
  [3,4,1,4,7,18], 
  [0,3,0,8,9,4], 
  [0,0,3,2,0,4], 
  [0,0,0,5,17,7],
  [0,0,0,0,5,3], 
  [0,0,0,0,0,3 :: rat]])"

(* Product of a matrix given by two rows of length 3 with a matrix built by
   "mat 3 4" from an index function; the result is printed via show.
   No element type is annotated here, so it is left to type inference. *)
value[code] "show (mat_of_rows_list 3 [[1, 4, 5], [3, 6, 8]] * mat 3 4 (λ (i,j). i + 2 * j))"

text ‹Inverses can only be computed for matrices over fields.›

(* First example: the same 4 x 4 rational matrix as in the determinant tests. *)
value[code] "show (mat_inverse (mat_of_rows_list 4 [[1 :: rat, 4, 9, -1], [-3, -1, 5, 4], [4, 2, 0,2], [8,-9, 5,7]]))"

(* Second example: here the third row [-2, 3, 14, 3] is the sum of the first
   two rows, so the matrix is singular.
   NOTE(review): presumably this exercises the failure case of mat_inverse;
   confirm against its definition. *)
value[code] "show (mat_inverse (mat_of_rows_list 4 [[1 :: rat, 4, 9, -1], [-3, -1, 5, 4], [-2, 3,14,3], [8,-9, 5,7]]))"

end

Theory Strassen_Algorithm

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Strassen's algorithm for matrix multiplication.›

text ‹We define the algorithm for arbitrary matrices over rings,
  where an alignment of the dimensions to even numbers will 
  be performed throughout the algorithm.›

theory Strassen_Algorithm
imports 
  Matrix
begin

text ‹With @{const four_block_mat} and @{const split_block} we can define Strassen's 
  multiplication algorithm.›

text ‹We start with a simple heuristic on when to switch to the basic algorithm.›

(* Dimension threshold (fixed to 20) against which strassen_too_small compares
   the matrix dimensions.  The defining equation is registered as a
   code_unfold rule. *)
definition strassen_constant :: nat where
  [code_unfold]: "strassen_constant = 20"

(* Heuristic for switching to the basic multiplication: compares dim_row A,
   dim_col A and dim_col B with strassen_constant.
   NOTE(review): the connectives between the three comparisons are missing in
   this text.  Lemma strassen_measure_div_2 derives "not dim_row A < 2" from
   "not strassen_too_small A B", which fits a disjunction (too small as soon as
   one dimension is below the threshold) -- confirm. *)
definition "strassen_too_small A B  
  dim_row A < strassen_constant  
  dim_col A < strassen_constant  
  dim_col B < strassen_constant"

text ‹We have to make a case analysis on whether all dimensions are even.›
(* Evenness test on the three dimensions dim_row A, dim_col A and dim_col B
   that are relevant for the product of A and B.  The connectives are missing
   in this text; the surrounding text ("whether all dimensions are even")
   indicates a conjunction. *)
definition "strassen_even A B  even (dim_row A)  even (dim_col A)  even (dim_col B)"

text ‹And then we can define the algorithm.›

(* Strassen multiplication of A and B, with nr = dim_row A, n = dim_col A and
   nc = dim_col B:
   - if strassen_too_small A B holds, the ordinary product A * B is returned;
   - if strassen_even A B holds, A and B are each split into four blocks at
     half the dimensions (nr div 2, n div 2, nc div 2), the seven products
     M1 .. M7 are computed by recursive calls, and the four result blocks
     C1 .. C4 are assembled from them by additions and subtractions;
   - otherwise A and B are split at the rounded-down even dimensions
     ((x div 2) * 2); only the product of the first blocks A1 and B1 is
     computed recursively, every product involving one of the other blocks
     uses ordinary matrix multiplication.
   The result is put together with four_block_mat.  The proof obligations of
   the function package are discharged by pat_completeness and auto; the
   termination proof is given separately. *)
function strassen_mat_mult :: "'a :: ring mat  'a mat  'a mat" where
  "strassen_mat_mult A B = (let nr = dim_row A; n = dim_col A; nc = dim_col B in
    if strassen_too_small A B then A * B else 
    if strassen_even A B then let
      nr2 = nr div 2;
      n2 = n div 2;
      nc2 = nc div 2;
      (A1,A2,A3,A4) = split_block A nr2 n2;
      (B1,B2,B3,B4) = split_block B n2 nc2;
      M1 = strassen_mat_mult (A1 + A4) (B1 + B4);
      M2 = strassen_mat_mult (A3 + A4) B1;
      M3 = strassen_mat_mult A1 (B2 - B4);
      M4 = strassen_mat_mult A4 (B3 - B1);
      M5 = strassen_mat_mult (A1 + A2) B4;
      M6 = strassen_mat_mult (A3 - A1) (B1 + B2);
      M7 = strassen_mat_mult (A2 - A4) (B3 + B4);
      C1 = M1 + M4 - M5 + M7;
      C2 = M3 + M5;
      C3 = M2 + M4;
      C4 = M1 - M2 + M3 + M6
    in four_block_mat C1 C2 C3 C4 else 
    let 
     nr' = (nr div 2) * 2;
     n' = (n div 2) * 2;
     nc' = (nc div 2) * 2;
     (A1,A2,A3,A4) = split_block A nr' n';
     (B1,B2,B3,B4) = split_block B n' nc';
     C1 = strassen_mat_mult A1 B1 + A2 * B3;
     C2 = A1 * B2 + A2 * B4;
     C3 = A3 * B1 + A4 * B3;
     C4 = A3 * B2 + A4 * B4
     in four_block_mat C1 C2 C3 C4)"
  by pat_completeness auto

text ‹For termination, we use the following measure.›

(* Termination measure on argument pairs (A, B): twice the sum
   dim_row A + dim_col A + dim_col B, plus 1 if strassen_even A B does not
   hold.  Counting the sum twice lets a decrease of the sum by one outweigh
   the additional 1 of the parity term. *)
definition "strassen_measure  λ (A,B). (dim_row A + dim_col A + dim_col B)
  + (dim_row A + dim_col A + dim_col B) + (if strassen_even A B then 0 else 1)"

(* Simplification rules: in either argument position, strassen_measure of a
   sum or difference equals the measure with the right operand alone, and the
   measure of a negated matrix equals the measure of the matrix itself.
   These rules remove the additions and subtractions that occur in the
   arguments of the recursive calls of strassen_mat_mult. *)
lemma strassen_measure_add[simp]: 
  "strassen_measure (A + B, C) = strassen_measure (B,C)" 
  "strassen_measure (A, B + C) = strassen_measure (A,C)" 
  "strassen_measure (A - B, C) = strassen_measure (B,C)" 
  "strassen_measure (A, B - C) = strassen_measure (A,C)" 
  "strassen_measure (- A, B) = strassen_measure (A,B)"
  "strassen_measure (A, - B) = strassen_measure (A,B)"
  unfolding strassen_measure_def strassen_even_def by auto

(* Measure decrease for the even branch of strassen_mat_mult: if A and B are
   split at half their dimensions and the pair (A, B) is not too small, then
   the listed pairs of blocks have a strictly smaller strassen_measure than
   (A, B).  Together with the simp rules strassen_measure_add, the listed
   pairs cover the arguments of the seven recursive calls. *)
lemma strassen_measure_div_2: assumes "(A1, A2, A3, A4) = split_block A (dim_row A div 2) (dim_col A div 2)"
  "(B1, B2, B3, B4) = split_block B (dim_col A div 2) (dim_col B div 2)"  
  and large: "¬ strassen_too_small A B"
  shows 
  "strassen_measure (A1,B4) < strassen_measure (A,B)"
  "strassen_measure (A1,B2) < strassen_measure (A,B)"
  "strassen_measure (A2,B4) < strassen_measure (A,B)"
  "strassen_measure (A3,B2) < strassen_measure (A,B)"
  "strassen_measure (A4,B1) < strassen_measure (A,B)"
  "strassen_measure (A4,B3) < strassen_measure (A,B)"
  "strassen_measure (A4,B4) < strassen_measure (A,B)"
proof -
  {
    (* Generic argument for an arbitrary block Ai of A and an arbitrary block Bi of B. *)
    fix Ai Bi
    assume Ai: "Ai  {A1,A2,A3,A4}" and Bi: "Bi  {B1,B2,B3,B4}"
    (* Not being too small guarantees at least two rows, so halving the row
       dimension yields strictly fewer rows. *)
    from large[unfolded strassen_too_small_def strassen_constant_def]
    have "¬ dim_row A < 2" by auto 
    with assms Ai Bi have Ar:
      "dim_row Ai < dim_row A"
      "dim_col Ai  dim_col A"
      "dim_col Bi  dim_col B" 
      unfolding split_block_def Let_def by auto    
    (* The doubled dimension sum decreases, which also covers a possible
       additional 1 from the parity term of the measure. *)
    hence "strassen_measure (Ai,Bi) < strassen_measure (A,B)"
      unfolding strassen_measure_def split by auto
  }
  (* Instantiate the generic fact for the seven concrete block pairs. *)
  thus
    "strassen_measure (A1,B2) < strassen_measure (A,B)"
    "strassen_measure (A1,B4) < strassen_measure (A,B)"
    "strassen_measure (A2,B4) < strassen_measure (A,B)"
    "strassen_measure (A3,B2) < strassen_measure (A,B)"
    "strassen_measure (A4,B1) < strassen_measure (A,B)"
    "strassen_measure (A4,B3) < strassen_measure (A,B)"
    "strassen_measure (A4,B4) < strassen_measure (A,B)"
    by auto
qed

(* Measure decrease for the odd branch of strassen_mat_mult: if A and B are
   split at their rounded-down even dimensions and strassen_even A B does not
   hold, then the pair of first blocks (A1, B1) has a strictly smaller
   strassen_measure than (A, B). *)
lemma strassen_measure_odd: assumes "(A1, A2, A3, A4) = split_block A ((dim_row A div 2) * 2) ((dim_col A div 2) * 2)"  
  and "(B1, B2, B3, B4) = split_block B ((dim_col A div 2) * 2) ((dim_col B div 2) * 2)"
  and odd: "¬ strassen_even A B"
  shows "strassen_measure (A1,B1) < strassen_measure (A,B)"
proof -
  (* For each of the three dimensions: the block dimension is either strictly
     smaller, or unchanged with the original dimension being even.
     (The connectives are missing in this text; this reading follows the
     shape of the three facts and their use in the final step.) *)
  from assms have Ar:
    "dim_row A1 < dim_row A  dim_row A1 = dim_row A  even (dim_row A)" 
    unfolding split_block_def Let_def by auto presburger
  from assms have Ac:
    "dim_col A1 < dim_col A  dim_col A1 = dim_col A  even (dim_col A)" 
    unfolding split_block_def Let_def by auto presburger
  from assms have Bc:
    "dim_col B1 < dim_col B  dim_col B1 = dim_col B  even (dim_col B)" 
    unfolding split_block_def Let_def by auto presburger
  (* Combine the three facts with the assumption that (A, B) is not even. *)
  from Ar Ac Bc odd show ?thesis unfolding strassen_measure_def strassen_even_def split
    by (auto split: if_splits)
qed

(* Termination of strassen_mat_mult: the relation is the measure induced by
   strassen_measure; the obligations for the recursive calls are solved by
   auto using strassen_measure_div_2 and strassen_measure_odd as elimination
   rules (and the simp rules strassen_measure_add). *)
termination by (relation "measure strassen_measure", 
   auto elim: strassen_measure_div_2 strassen_measure_odd)


(* Correctness of Strassen's algorithm: whenever the dimensions of A and B
   fit (dim_col A = dim_row B), strassen_mat_mult A B equals the ordinary
   product A * B.  The proof is by the induction rule generated for
   strassen_mat_mult and follows the three branches of its definition. *)
lemma strassen_mat_mult: 
  "dim_col A = dim_row B  strassen_mat_mult A B = A * B"
proof (induct A B rule: strassen_mat_mult.induct)
  case (1 A B)
  let ?nr = "dim_row A"
  let ?nc = "dim_col B"
  let ?n = "dim_col A"
  show ?case
  proof (cases "strassen_too_small A B")
    case False note large = this
    let ?smm = strassen_mat_mult
    (* Specialise the eight induction hypotheses (seven recursive calls in the
       even branch, one in the odd branch) to the let-bound values of the
       definition; the remaining premises are filled in per branch below. *)
    note IH = 1(1-8)[OF refl refl refl False _ refl refl refl _ refl refl refl _ refl refl refl]
    show ?thesis
    proof (cases "strassen_even A B")
      case True
      (* Even branch: every dimension is exactly twice its half. *)
      note even = True[unfolded strassen_even_def]
      let ?nr2 = "?nr div 2"
      let ?n2 = "?n div 2"
      let ?nc2 = "?nc div 2"
      from even have nr: "?nr = ?nr2 + ?nr2" by presburger 
      from even have n: "?n = ?n2 + ?n2" by presburger 
      from even have nc: "?nc = ?nc2 + ?nc2" by presburger 
      from 1(9) even have n': "dim_row B = ?n2 + ?n2"
        by auto
      (* Name the blocks of A and B. *)
      obtain A1 A2 A3 A4 where splitA: 
        "split_block A ?nr2 ?n2 = (A1,A2,A3,A4)" by (rule prod_cases4)
      obtain B1 B2 B3 B4 where splitB: 
        "split_block B ?n2 ?nc2 = (B1,B2,B3,B4)" by (rule prod_cases4)
      note IH = IH(1-7)[OF True splitA[symmetric] splitB[symmetric] ]
      (* A and B are the block matrices of their parts, and all blocks have
         the halved dimensions. *)
      from split_block[OF splitA nr n]
      have blockA: "A = four_block_mat A1 A2 A3 A4"
        and A1: "A1  carrier_mat ?nr2 ?n2" 
        and A2: "A2  carrier_mat ?nr2 ?n2" 
        and A3: "A3  carrier_mat ?nr2 ?n2" 
        and A4: "A4  carrier_mat ?nr2 ?n2" 
        by blast+
      from split_block[OF splitB n' nc]
      have blockB: "B = four_block_mat B1 B2 B3 B4"
        and B1: "B1  carrier_mat ?n2 ?nc2" 
        and B2: "B2  carrier_mat ?n2 ?nc2" 
        and B3: "B3  carrier_mat ?n2 ?nc2" 
        and B4: "B4  carrier_mat ?n2 ?nc2" 
        by blast+
      note carr = A1 A2 A3 A4 B1 B2 B3 B4
      (* ?Mi1 and ?Mi2 are the left and right argument of the i-th recursive
         call, ?Mi is its result, ?C1 .. ?C4 are the assembled result blocks;
         all exactly as in the definition of strassen_mat_mult. *)
      let ?M11 = "A1 + A4" let ?M12 = "B1 + B4"
      let ?M21 = "A3 + A4" let ?M22 = "B1"
      let ?M31 = "A1" let ?M32 = "B2 - B4"
      let ?M41 = "A4" let ?M42 = "B3 - B1"
      let ?M51 = "A1 + A2" let ?M52 = "B4"
      let ?M61 = "A3 - A1" let ?M62 = "B1 + B2"
      let ?M71 = "A2 - A4" let ?M72 = "B3 + B4"
      let ?M1 = "?smm ?M11 ?M12"
      let ?M2 = "?smm ?M21 ?M22"
      let ?M3 = "?smm ?M31 ?M32"
      let ?M4 = "?smm ?M41 ?M42"
      let ?M5 = "?smm ?M51 ?M52"
      let ?M6 = "?smm ?M61 ?M62"
      let ?M7 = "?smm ?M71 ?M72"
      let ?C1 = "?M1 + ?M4 - ?M5 + ?M7"
      let ?C2 = "?M3 + ?M5"
      let ?C3 = "?M2 + ?M4"
      let ?C4 = "?M1 - ?M2 + ?M3 + ?M6"
      (* One unfolding step of the definition in the even branch. *)
      have res: "?smm A B = four_block_mat ?C1 ?C2 ?C3 ?C4"
        using large True
        unfolding strassen_mat_mult.simps[of A B] Let_def splitA splitB split
        by auto
      (* By the induction hypotheses, each recursive result equals the
         ordinary product of its arguments.  After each use the list IH is
         shifted by one and the next let-binding premise is discharged. *)
      have M1: "?M1 = ?M11 * ?M12"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M2: "?M2 = ?M21 * ?M22"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M3: "?M3 = ?M31 * ?M32"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M4: "?M4 = ?M41 * ?M42"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M5: "?M5 = ?M51 * ?M52"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M6: "?M6 = ?M61 * ?M62"
        by (rule IH(1), insert carr, auto)
      note IH = IH(2-)[OF refl]
      have M7: "?M7 = ?M71 * ?M72"
        by (rule IH(1), insert carr, auto)
      (* Algebraic facts for the final comparison, instantiated to the block
         dimensions: distributivity, closure under addition and negation,
         associativity and commutativity of addition. *)
      note distr = 
        add_mult_distrib_mat[of _ ?nr2 ?n2 _ _ ?nc2]
        minus_mult_distrib_mat[of _ ?nr2 ?n2 _ _ ?nc2]
        mult_add_distrib_mat[of _ ?nr2 ?n2 _ ?nc2]
        mult_minus_distrib_mat[of _ ?nr2 ?n2 _ ?nc2]
      note closed = add_carrier_mat[of _ ?nr2 ?nc2]
         uminus_carrier_iff_mat[of _ ?nr2 ?nc2]
      note ac = assoc_add_mat[of _ ?nr2 ?nc2] comm_add_mat[of _ ?nr2 ?nc2]
      (* Rewrite A * B as a block product and compare it block by block with
         the result of the algorithm. *)
      show ?thesis unfolding res M1 M2 M3 M4 M5 M6 M7
        unfolding blockA blockB
          mult_four_block_mat[OF carr]
        by (rule cong_four_block_mat)
           (insert carr, auto simp: distr ac closed)
    next
      case False
      (* Odd branch: split at the rounded-down even dimensions ?x2; the primed
         abbreviations ?x2' denote what is left over in each dimension. *)
      let ?nr2 = "?nr div 2 * 2" let ?nr2' = "?nr - ?nr2"
      let ?n2 = "?n div 2 * 2"   let ?n2' = "?n - ?n2"
      let ?nc2 = "?nc div 2 * 2" let ?nc2' = "?nc - ?nc2"
      have nr: "?nr = ?nr2 + ?nr2'" by presburger 
      have n: "?n = ?n2 + ?n2'" by presburger 
      have nc: "?nc = ?nc2 + ?nc2'" by presburger 
      from 1(9) have n': "dim_row B = ?n2 + ?n2'" by auto   
      obtain A1 A2 A3 A4 where splitA: 
        "split_block A ?nr2 ?n2 = (A1,A2,A3,A4)" by (rule prod_cases4)
      obtain B1 B2 B3 B4 where splitB: 
        "split_block B ?n2 ?nc2 = (B1,B2,B3,B4)" by (rule prod_cases4)
      note IH = IH(8)[OF False splitA[symmetric] splitB[symmetric]]
      from split_block[OF splitA nr n]
      have blockA: "A = four_block_mat A1 A2 A3 A4"
        and A1: "A1  carrier_mat ?nr2 ?n2" 
        and A2: "A2  carrier_mat ?nr2 ?n2'" 
        and A3: "A3  carrier_mat ?nr2' ?n2" 
        and A4: "A4  carrier_mat ?nr2' ?n2'" 
        by blast+
      from split_block[OF splitB n' nc]
      have blockB: "B = four_block_mat B1 B2 B3 B4"
        and B1: "B1  carrier_mat ?n2 ?nc2" 
        and B2: "B2  carrier_mat ?n2 ?nc2'" 
        and B3: "B3  carrier_mat ?n2' ?nc2" 
        and B4: "B4  carrier_mat ?n2' ?nc2'" 
        by blast+      
      note carr = A1 A2 A3 A4 B1 B2 B3 B4
      (* The dimensions of A1 and B1 fit, so the induction hypothesis for the
         single recursive call applies. *)
      from carr have "dim_col A1 = dim_row B1" by simp
      note IH = IH[OF this]
      (* One unfolding step of the definition, with the recursive call already
         replaced by the ordinary product via IH. *)
      have "?smm A B = four_block_mat 
        (A1 * B1 + A2 * B3)
        (A1 * B2 + A2 * B4)
        (A3 * B1 + A4 * B3)
        (A3 * B2 + A4 * B4)"
        unfolding strassen_mat_mult.simps[of A B] Let_def 
          splitA splitB split IH using large False by auto
      (* This is exactly the block product of A and B. *)
      also have " = A * B"
        unfolding blockA blockB
         mult_four_block_mat[OF carr] by simp
      finally show ?thesis by simp
    qed
  (* Remaining case (strassen_too_small A B holds): solved by simp. *)
  qed simp
qed

end

Theory Strassen_Algorithm_Code

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Strassen's Algorithm as Code Equation›

text ‹We replace the code-equations for matrix-multiplication 
  by Strassen's algorithm. Note that this will strengthen the class-constraint
  for matrix multiplication from semirings to rings!›

theory Strassen_Algorithm_Code
imports 
  Strassen_Algorithm
begin

text ‹The aim is to replace the implementation of @{thm times_mat_def} by @{const strassen_mat_mult}.›

text ‹We first need a copy of standard matrix multiplication to execute the base case.›

(* basic_mat_mult is a separate name for ordinary matrix multiplication.  Its
   code equation builds the result matrix with dim_row A rows and dim_col B
   columns, whose entry at (i, j) combines row i of A with column j of B
   (the operator between "row A i" and "col B j" is missing in this text;
   presumably the scalar product). *)
definition "basic_mat_mult = (*)"
lemma basic_mat_mult_code[code]: "basic_mat_mult A B = mat (dim_row A) (dim_col B) (λ (i,j). row A i  col B j)"
  unfolding basic_mat_mult_def by auto

text ‹Next use this new matrix multiplication code within Strassen's algorithm.›
(* Code equation for strassen_mat_mult: the defining equation with
   basic_mat_mult_def folded, i.e. occurrences of ordinary matrix
   multiplication are expressed via basic_mat_mult. *)
lemmas strassen_mat_mult_code[code] = strassen_mat_mult.simps[folded basic_mat_mult_def]

text ‹And finally use Strassen's algorithm for implementing matrix-multiplication.›

(* Code equation for matrix multiplication: Strassen's algorithm is used when
   the dimensions fit (dim_col A = dim_row B), which is the precondition of the
   correctness lemma strassen_mat_mult; otherwise basic_mat_mult is used. *)
lemma mat_mult_code[code]: "A * B = (if dim_col A = dim_row B then strassen_mat_mult A B else basic_mat_mult A B)"
  using strassen_mat_mult[of A B] unfolding basic_mat_mult_def by auto

end

Theory Matrix_Comparison

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Comparison of Matrices›

text ‹We use matrices over ordered semirings to again define ordered semirings.
  There are two instances, one for ordinary semirings (where addition is monotone w.r.t. the
  strict ordering in a single argument); 
  and one for semirings like the arctic one, where addition is interpreted
  as maximum, and therefore monotonicity of the strict ordering in a single argument is no
  longer provided. 

  Both ordered semirings are used for checking termination proofs, where at the moment only
  the ordinary semiring is supported for checking complexity proofs.›

theory Matrix_Comparison
imports 
  Matrix
  Matrix.Ordered_Semiring
begin

(* Basic comparison operations on matrices over a type with an order
   (context ord), together with their introduction and destruction rules. *)
context ord
begin
(* Entrywise comparison of A and B, written infix with priority 50: the
   entries of A and B are compared at all positions (i, j) within the
   dimensions of A.  Only the dimensions of A are consulted.
   (The relation symbol between the entries is missing in this text;
   presumably "greater or equal", following the name mat_ge.) *)
definition mat_ge :: "'a mat  'a mat  bool" (infix "m" 50) where
  "A m B = ( i < dim_row A.  j < dim_col A. A $$ (i,j)  B $$ (i,j))"

(* Introduction rule: for A of dimension nr x nc it suffices to compare the
   entries at all positions i < nr, j < nc. *)
lemma mat_geI[intro]: assumes "A  carrier_mat nr nc" 
  " i j. i < nr  j < nc  A $$ (i,j)  B $$ (i,j)"
  shows "A m B"
  using assms unfolding mat_ge_def by auto

(* Destruction rule: the matrix comparison yields the comparison of the
   entries at every position within the dimensions of A. *)
lemma mat_geD[dest]: assumes "A m B" and "i < dim_row A" "j < dim_col A"
  shows "A $$ (i,j)  B $$ (i,j)" 
  using assms unfolding mat_ge_def by auto

(* Strict comparison, parametrised by an entry relation gt and a bound sd:
   A and B are related by mat_ge, and gt holds for the entries at a position
   (i, j) with i < sd and j < sd.  (The quantifiers are missing in this text;
   the rule mat_gtI below needs only one such position, so the position is
   presumably existentially quantified.) *)
definition mat_gt :: "('a  'a  bool)  nat  'a mat  'a mat  bool" where
  "mat_gt gt sd A B = (A m B  ( i < sd.  j < sd. gt (A $$ (i,j)) (B $$ (i,j))))"

(* Introduction rule: mat_ge plus one position below sd where gt holds. *)
lemma mat_gtI[intro]: assumes "A m B"
  and "i < sd" "j < sd" "gt (A $$ (i,j)) (B $$ (i,j))"
  shows "mat_gt gt sd A B"
  using assms unfolding mat_gt_def by auto

(* Destruction rule: both parts of the definition of mat_gt. *)
lemma mat_gtD[dest]: assumes "mat_gt gt sd A B"
  shows "A m B" " i < sd.  j < sd. gt (A $$ (i,j)) (B $$ (i,j))"
  using assms unfolding mat_gt_def by auto

(* Entrywise maximum of A and B; the result has the dimensions of A. *)
definition mat_max :: "'a mat  'a mat  'a mat" ("maxm") where
  "maxm A B = mat (dim_row A) (dim_col A) (λ ij. max (A $$ ij) (B $$ ij))"

(* The maximum always has the dimensions of its first argument. *)
lemma mat_max_carrier[simp]:
  "maxm A B  carrier_mat (dim_row A) (dim_col A)"
  unfolding mat_max_def by auto

(* Closure of the nr x nc matrices under the entrywise maximum. *)
lemma mat_max_closed[intro]:
  "A  carrier_mat nr nc  B  carrier_mat nr nc  maxm A B  carrier_mat nr nc"
  unfolding mat_max_def by auto

(* Entry of the maximum at a position within the dimensions of A. *)
lemma mat_max_index:
  assumes "i < dim_row A" "j < dim_col A"
  shows "(mat_max A B) $$ (i,j) = max (A $$ (i,j)) (B $$ (i,j))"
  unfolding mat_max_def using index_mat assms by auto

(* The n x n matrix with d on the diagonal and 0 everywhere else
   (defined in class zero, as it only needs the constant 0). *)
definition (in zero) mat_default :: "'a  nat  'a mat" ("defaultm") where
  "defaultm d n = mat n n (λ (i,j). if i = j then d else 0)"

lemma mat_default_carrier[simp]: "defaultm d n  carrier_mat n n"
  unfolding mat_default_def by auto
end


(* Column condition on the part of A with indices below sd, for a predicate P
   on entries: relates every column index j < sd to row indices i < sd with
   P (A $$ (i,j)).
   NOTE(review): the quantifiers are missing in this text.  Lemma mat_mono in
   context SN_strict_mono_ordered_semiring_1 obtains, for a given column, one
   row index k < sd with the property, which suggests "for all j there exists
   i" -- confirm. *)
definition mat_mono :: "('a  bool)  nat  'a mat  bool"
where "mat_mono P sd A = ( j < sd.  i < sd. P (A $$ (i,j)))"

(* Properties of the entrywise comparison mat_ge and the entrywise maximum
   mat_max that follow from the axioms of a non-strict order on the entries. *)
context non_strict_order
begin
(* Transitivity of mat_ge for matrices A and B of the same dimensions, lifted
   entrywise from ge_trans. *)
lemma mat_ge_trans: assumes "A m B" "B m C"
  and "A  carrier_mat nr nc" "B  carrier_mat nr nc"
shows "A m C"
  using assms ge_trans[of "B $$ (i,j)" "A $$ (i,j)" for i j] 
  unfolding mat_ge_def carrier_mat_def by auto

(* Reflexivity of mat_ge, from ge_refl. *)
lemma mat_ge_refl: "A m A"
  unfolding mat_ge_def by (auto simp: ge_refl)

(* Commutativity of mat_max for matrices of equal dimensions, from max_comm. *)
lemma mat_max_comm: "A  carrier_mat nr nc  B  carrier_mat nr nc  maxm A B = maxm B A"
  unfolding mat_max_def by (intro eq_matI, auto simp: max_comm)

(* The maximum is related to its first argument by mat_ge. *)
lemma mat_max_ge: "maxm A B m A"
  unfolding mat_max_def by (intro mat_geI[of _ "dim_row A" "dim_col A"], auto)

(* If A and B are related by mat_ge, then the maximum is A itself (max_id). *)
lemma mat_max_ge_0: "A  carrier_mat nr nc  B  carrier_mat nr nc  A m B  maxm A B = A"
  unfolding mat_max_def by (intro eq_matI, auto simp: max_id)

(* Monotonicity of mat_max in its second argument, from max_mono. *)
lemma mat_max_mono: "A m B 
   A  carrier_mat nr nc  B  carrier_mat nr nc  C  carrier_mat nr nc  
   maxm C A m maxm C B"
  by (intro mat_geI[of _ nr nc], auto simp: max_mono mat_max_def)
end

(* Monotonicity of matrix addition in the left argument with respect to
   mat_ge, for nr x nc matrices over an ordered_ab_semigroup; lifted entrywise
   from plus_left_mono. *)
lemma mat_plus_left_mono: "A m (B :: 'a :: ordered_ab_semigroup mat) 
   A  carrier_mat nr nc  B  carrier_mat nr nc  C  carrier_mat nr nc 
   A + C m B + C"
  by (intro mat_geI[of _ nr nc], auto simp: plus_left_mono)

(* Monotonicity of matrix addition in the right argument with respect to
   mat_ge, for nr x nc matrices over an ordered_ab_semigroup; lifted entrywise
   from plus_right_mono. *)
lemma mat_plus_right_mono: "B m (C :: 'a :: ordered_ab_semigroup mat) 
   A  carrier_mat nr nc  B  carrier_mat nr nc  C  carrier_mat nr nc 
   A + B m A + C"
  by (intro mat_geI[of _ nr nc], auto simp: plus_right_mono)

(* Monotonicity of addition in both arguments at once, obtained by chaining
   plus_left_mono and plus_right_mono with ge_trans. *)
lemma plus_mono: "x1  (x2 :: 'a :: ordered_ab_semigroup)  
  y1  y2  x1 + y1  x2 + y2"
  using ge_trans[OF plus_left_mono[of x2 x1] plus_right_mono[of y2 y1]] .

text ‹Since one cannot use @{thm sum_mono} (it requires other 
  class constraints like @{class order}), we make our own copy of this
  fact.›

(* Monotonicity of sums over an index set K in an ordered_semiring_0: if f and
   g are related pointwise on K, then so are their sums over K. *)
lemma sum_mono_ge:
  assumes ge: "i. iK  f (i::'a)  ((g i)::('b::ordered_semiring_0))"
  shows "(iK. f i)  (iK. g i)"
proof (cases "finite K")
  case True
  (* Finite K: induction over the set, combining the summands with plus_mono. *)
  thus ?thesis using ge
  proof induct
    case empty
    show ?case by (simp add: ge_refl)
  next
    case insert
    thus ?case using plus_mono by fastforce
  qed
next
  (* Infinite K: both sides simplify to the same value, so reflexivity suffices. *)
  case False then show ?thesis by (simp add: ge_refl)
qed

(* Strict monotonicity of sums over a finite index set K: if f and g are
   related pointwise by the non-strict order on K and by the strict order at
   one index i of K, then the sums are related by the strict order.
   (The relation symbols are missing in this text; the reading follows the
   assumption names le / gt and the rules used in the proof.) *)
lemma (in one_mono_ordered_semiring_1) sum_mono_gt:
  assumes le: "i. iK  f (i::'b)  ((g i)::'a)"
  and i: "i  K"
  and gt: "f i  g i"
  and K: "finite K"
  shows "(iK. f i)  (iK. g i)"
proof -
  (* Split off the summand at index i. *)
  have id: " f. (iK. f i) = f i + (i K - {i}. f i)"
    by (rule sum.remove[OF K i])
  (* The remaining sums are related by the non-strict order. *)
  have ge: "(i K - {i}. f i)  (i K - {i}. g i)"
    by (rule sum_mono_ge[OF le], auto)
  (* Combine the strict step at i with the non-strict rest via compat. *)
  show ?thesis unfolding id using compat[OF plus_right_mono[OF ge] plus_gt_left_mono[OF gt]] .
qed

(* Monotonicity in the left argument of the operation defined by
   scalar_prod_def, for vectors u, v, w of dimension n over an
   ordered_semiring_0: u and v are related componentwise, and every component
   of w is related to 0 (relation symbols are missing in this text; presumably
   "greater or equal").  Proved summand by summand with sum_mono_ge and
   times_left_mono. *)
lemma scalar_left_mono: assumes 
  "u  carrier_vec n" "v  carrier_vec n" "w  carrier_vec n" 
  and " i. i < n  u $ i  v $ i"
  and " i. i < n  w $ i  (0 :: 'a :: ordered_semiring_0)"
  shows "u  w  v  w" unfolding scalar_prod_def
  by (intro sum_mono_ge times_left_mono, insert assms, auto)

(* Monotonicity in the right argument of the operation defined by
   scalar_prod_def, for vectors u, v, w of dimension n over an
   ordered_semiring_0: v and w are related componentwise, and every component
   of u is related to 0 (relation symbols are missing in this text; presumably
   "greater or equal"). *)
lemma scalar_right_mono: assumes 
  "u  carrier_vec n" "v  carrier_vec n" "w  carrier_vec n" 
  and " i. i < n  v $ i  w $ i"
  and " i. i < n  u $ i  (0 :: 'a :: ordered_semiring_0)"
  shows "u  v  u  w" 
proof -
  (* Both sums must range over the same index set, hence the dimensions of
     v and w are identified first. *)
  have dim: "dim_vec v = dim_vec w" using assms by auto
  show ?thesis unfolding scalar_prod_def dim
    by (intro sum_mono_ge times_right_mono, insert assms, auto)
qed

(* Monotonicity of matrix multiplication in the left argument with respect to
   mat_ge, for n x n matrices over an ordered_semiring_0, where the common
   right factor C is related to the zero matrix by mat_ge. *)
lemma mat_mult_left_mono: assumes C0: "C m 0m n n"
  and AB: "A m (B :: 'a :: ordered_semiring_0 mat)"
  and carr: "A  carrier_mat n n" "B  carrier_mat n n" "C  carrier_mat n n"
  shows "A * C m B * C"
proof -
  {
    (* Entry (i, j) of each product is built from a row of the left factor and
       column j of C; compare these with scalar_left_mono. *)
    fix i j
    assume i: "i < n" "j < n"
    have "row A i  col C j  row B i  col C j"
      by (rule scalar_left_mono[of _ n], insert C0 AB carr i, auto)
  }
  thus ?thesis 
    by (intro mat_geI[of _ n n], insert carr, auto)
qed

(* Monotonicity of matrix multiplication in the right argument with respect to
   mat_ge, for n x n matrices over an ordered_semiring_0, where the common
   left factor A is related to the zero matrix by mat_ge. *)
lemma mat_mult_right_mono: assumes A0: "A m 0m n n" 
  and BC: "B m (C :: 'a :: ordered_semiring_0 mat)"
  and carr: "A  carrier_mat n n" "B  carrier_mat n n" "C  carrier_mat n n"
  shows "A * B m A * C"
proof -
  {
    (* Entry (i, j) of each product is built from row i of A and a column of
       the right factor; compare these with scalar_right_mono. *)
    fix i j
    assume i: "i < n" "j < n"
    have "row A i  col B j  row A i  col C j"
      by (rule scalar_right_mono[of _ n], insert A0 BC carr i, auto)
  }
  thus ?thesis 
    by (intro mat_geI[of _ n n], insert carr, auto)
qed

(* The n x n identity matrix is related to the n x n zero matrix by mat_ge,
   over an ordered_semiring_1; uses one_ge_zero and ge_refl on the entries. *)
lemma one_mat_ge_zero: "(1m n :: 'a :: ordered_semiring_1 mat) m 0m n n"
  by (intro mat_geI[of _ n n], auto simp: one_ge_zero ge_refl)

(* Interaction of mat_ge with mat_gt, where mat_gt is built from the strict
   relation gt of the order pair.  All lemmas are for n x n matrices and a
   bound sd that is related to n by assumption sd (the relation symbol is
   missing in this text; presumably "sd at most n", since positions below sd
   are used as positions below n). *)
context order_pair
begin
(* Compatibility: a mat_ge step followed by a mat_gt step is a mat_gt step. *)
lemma mat_ge_gt_trans: assumes sd: "sd  n" and AB: "A m B" and BC: "mat_gt gt sd B C"
  and A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
shows "mat_gt gt sd A C"
proof -  
  (* Position (i, j) is where the strict decrease from B to C happens. *)
  from mat_gtD[OF BC] obtain i j where ij: "i < sd" "j < sd" and gt: "B $$ (i, j)  C $$ (i, j)" 
    and BC: "B m C" by auto
  from mat_ge_trans[OF AB BC A B] have AC: "A m C" .
  from mat_geD[OF AB, of i j] A sd ij have ge: "A $$ (i,j)  B $$ (i,j)" by auto  
  (* Combine the two entry facts at (i, j) with compat. *)
  from compat[OF ge gt] have gt: "A $$ (i, j)  C $$ (i, j)" .
  with ij AC show ?thesis by auto
qed

(* Compatibility: a mat_gt step followed by a mat_ge step is a mat_gt step. *)
lemma mat_gt_ge_trans: assumes sd: "sd  n" and AB: "mat_gt gt sd A B" and BC: "B m C"
  and A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
shows "mat_gt gt sd A C"
proof -  
  (* Position (i, j) is where the strict decrease from A to B happens. *)
  from mat_gtD[OF AB] obtain i j where ij: "i < sd" "j < sd" and gt: "A $$ (i, j)  B $$ (i, j)" 
    and AB: "A m B" by auto
  from mat_ge_trans[OF AB BC A B] have AC: "A m C" .
  from mat_geD[OF BC, of i j] B sd ij have ge: "B $$ (i,j)  C $$ (i,j)" by auto  
  (* Combine the two entry facts at (i, j) with compat2. *)
  from compat2[OF gt ge] have gt: "A $$ (i, j)  C $$ (i, j)" .
  with ij AC show ?thesis by auto
qed

(* mat_gt includes mat_ge (first part of mat_gtD). *)
lemma mat_gt_imp_mat_ge: "mat_gt gt sd A B  A m B"
  by (rule mat_gtD)

(* Transitivity of mat_gt, by weakening the first step to mat_ge. *)
lemma mat_gt_trans: assumes sd: "sd  n" and AB: "mat_gt gt sd A B" and BC: "mat_gt gt sd B C"
  and A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
shows "mat_gt gt sd A C"
  using mat_ge_gt_trans[OF sd mat_gt_imp_mat_ge[OF AB] BC A B] .

(* The default matrix is related to the zero matrix by mat_ge; uses
   default_ge_zero on the diagonal and ge_refl elsewhere. *)
lemma mat_default_ge_0: "defaultm default n m 0m n n"
  by (intro mat_geI[of _ n n], auto simp: mat_default_def default_ge_zero ge_refl)
end

(* The ordered-semiring structure on matrices of dimension n: it is based on
   ring_mat TYPE('a) n, with the fields geq, gt and max set to mat_ge,
   mat_gt gt sd and mat_max, and with b as the remaining extension part.
   NOTE(review): the record brackets and the name of the last field are
   missing in this text; the description follows the visible field
   assignments. *)
definition mat_ordered_semiring :: "nat  nat  ('a :: ordered_semiring_1  'a  bool)  'b  ('a mat,'b) ordered_semiring_scheme" where
  "mat_ordered_semiring n sd gt b  ring_mat TYPE('a) n 
    ordered_semiring.geq = (≥m),
    gt = mat_gt gt sd,
    max = maxm,
     = b"

(* The structure mat_ordered_semiring n sd, instantiated with the strict order
   of this locale, satisfies the ordered_semiring locale, under the assumption
   sd_n relating sd and n (relation symbol missing in this text; presumably
   "sd at most n"). *)
lemma (in one_mono_ordered_semiring_1) mat_ordered_semiring: assumes sd_n: "sd  n" 
  shows "ordered_semiring 
    (mat_ordered_semiring n sd (≻) b :: ('a mat,'b) ordered_semiring_scheme)" 
  (is "ordered_semiring ?R")
proof -
  (* The semiring part is inherited from semiring_mat. *)
  interpret semiring ?R unfolding mat_ordered_semiring_def by (rule semiring_mat)
  (* The order axioms are exactly the matrix lemmas proved above. *)
  show ?thesis 
    by (unfold_locales, unfold ring_mat_def mat_ordered_semiring_def ordered_semiring_record_simps,
    auto intro: mat_ge_trans mat_ge_refl mat_ge_gt_trans[OF sd_n] mat_gt_ge_trans[OF sd_n] mat_max_comm
    mat_max_ge mat_max_ge_0 mat_max_mono one_mat_ge_zero mat_gt_trans[OF sd_n] mat_gt_imp_mat_ge
    mat_plus_left_mono mat_mult_left_mono mat_mult_right_mono)
qed

context weak_SN_strict_mono_ordered_semiring_1
begin

(* Lifts weak_gt_mono to matrices: if every pair (A, B) of n x n matrices in
   the list ABs is oriented by mat_gt weak_gt sd, then there is a single
   relation gt that satisfies SN_strict_mono_ordered_semiring_1 default gt mono
   and for which all these pairs are oriented by mat_gt gt sd.
   (Quantifiers and connectives are missing in this text; this reading follows
   the structure of the proof.) *)
lemma weak_mat_gt_mono: assumes sd_n: "sd  n" and
    orient: " A B. A  carrier_mat n n  B  carrier_mat n n  (A,B)  set ABs  mat_gt weak_gt sd A B"
   shows " gt. SN_strict_mono_ordered_semiring_1 default gt mono  
     ( A B. A  carrier_mat n n  B  carrier_mat n n  (A, B)  set ABs  mat_gt gt sd A B)"
proof -
  (* Collect all entries (indices below n) of the left and of the right
     matrices of ABs, form all pairs of a left and a right entry, and keep the
     pairs that are related by weak_gt. *)
  let ?n = "[0 ..< n]"
  let ?m1x = "[ A $$ (i,j) . A <- map fst ABs, i <- ?n, j <- ?n]"
  let ?m2y = "[ B $$ (i,j) . B <- map snd ABs, i <- ?n, j <- ?n]"
  let ?pairs = "concat (map (λ x. map (λ y. (x,y)) ?m2y) ?m1x)"
  let ?strict = "filter (λ (x,y). weak_gt x y) ?pairs"
  have " x y. (x,y)  set ?strict  weak_gt x y" by auto
  (* weak_gt_mono yields one relation gt that orients all collected pairs. *)
  from weak_gt_mono[OF this] obtain gt where order: "SN_strict_mono_ordered_semiring_1 default gt mono" 
    and orient2: " x y. (x, y)  set ?strict  gt x y" by auto
  show ?thesis
  proof (intro exI allI conjI impI, rule order)
    fix A B
    assume A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
      and AB: "(A, B)  set ABs"          
    from orient[OF this] have "mat_gt weak_gt sd A B" by auto
    (* (i, j) is a position with a weak_gt decrease; it is also a position
       below n, so the entry pair is among the collected pairs. *)
    from mat_gtD[OF this] obtain i j where
      ge: "A m B" and ij: "i < sd" "j < sd" and wgt: "weak_gt (A $$ (i,j)) (B $$ (i,j))"
      by auto
    from ij sd  n have ij': "i < n" "j < n" by auto
    have gt: "gt (A $$ (i,j)) (B $$ (i,j))"
      by (rule orient2, insert ij' AB wgt, force)
    show "mat_gt gt sd A B" using ij gt ge by auto
  qed
qed
end

(* Monotonicity of sum_mat with respect to mat_ge, for nr x nc matrices over
   an ordered_semiring_0. *)
lemma sum_mat_mono: 
  assumes A: "A  carrier_mat nr nc" and B: "B  carrier_mat nr nc" 
  and AB: "A m (B :: 'a :: ordered_semiring_0 mat)"
  shows "sum_mat A  sum_mat B"
proof -
  (* Rewrite the dimensions of B to those of A, so that both sums range over
     the same index set. *)
  from A B have id: "dim_row B = dim_row A" "dim_col B = dim_col A" by auto
  show ?thesis unfolding sum_mat_def id
    by (rule sum_mono_ge, insert mat_geD[OF AB] id, auto)
qed

(* Properties of mat_gt for n x n matrices in a one_mono_ordered_semiring_1.
   The assumptions relating sd and n have lost their relation symbol in this
   text; presumably "sd at most n". *)
context one_mono_ordered_semiring_1
begin
(* Strict monotonicity of sum_mat with respect to mat_gt. *)
lemma sum_mat_mono_gt: 
  assumes "sd  n"
  and A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
  and AB: "mat_gt (≻) sd A (B :: 'a mat)"
  shows "sum_mat A  sum_mat B"
proof -
  from A B have id: "dim_row B = dim_row A" "dim_col B = dim_col A" by auto
  (* (i, j) is a position with a strict decrease; it serves as the strict
     index for sum_mono_gt. *)
  from mat_gtD[OF AB] obtain i j where AB: "A m B" and 
    ij: "i < sd" "j < sd" and gt: "A $$ (i,j)  B $$ (i,j)" by auto
  show ?thesis unfolding sum_mat_def id
    by (rule sum_mono_gt[of _ _ _ "(i,j)"], insert ij gt mat_geD[OF AB] A B sd  n, auto)
qed

(* Adding the same matrix C on the right of both sides preserves mat_gt. *)
lemma mat_plus_gt_left_mono: assumes sd_n: "sd  n" and gt: "mat_gt (≻) sd A B"  
  and A: "A  carrier_mat n n" and B: "B  carrier_mat n n" and C: "C  carrier_mat n n"
  shows "mat_gt (≻) sd (A + C) (B + C)"
proof -
  note wf = A B C
  from mat_gtD[OF gt] obtain i j 
    where AB: "A m B" and ij: "i < sd" "j < sd" and gt: "A $$ (i,j)  B $$ (i,j)" by auto
  (* The strict decrease at (i, j) survives the addition of C $$ (i,j). *)
  from plus_gt_left_mono[OF gt, of "C $$ (i,j)"]
  show ?thesis
    by (intro mat_gtI[OF mat_geI[of _ n n] ij], insert mat_geD[OF AB] wf ij sd_n, auto intro: plus_left_mono)
qed

(* Adding two mat_gt-related pairs gives a mat_gt-related pair: the first pair
   provides the strict step (mat_plus_gt_left_mono), the second one is only
   used as a mat_ge step (mat_plus_right_mono). *)
lemma mat_gt_ge_mono: "sd  n  mat_gt gt sd A B 
   mat_gt gt sd C D 
   A  carrier_mat n n 
   B  carrier_mat n n 
   C  carrier_mat n n 
   D  carrier_mat n n 
   mat_gt gt sd (A + C) (B + D)"
  by (rule mat_gt_ge_trans[OF _ mat_plus_gt_left_mono mat_plus_right_mono[OF mat_gt_imp_mat_ge]],
  auto)

(* For a positive bound sd, the default matrix is related to the zero matrix
   by mat_gt; the strict decrease is at position (0, 0), using
   default_gt_zero. *)
lemma mat_default_gt_mat0: assumes sd_pos: "sd > 0" and sd_n: "sd  n"
  shows "mat_gt (≻) sd (defaultm default n) (0m n n)"
proof -
  from assms have n: "n > 0" by auto
  show ?thesis
    by (intro mat_gtI[OF mat_default_ge_0 sd_pos sd_pos], insert sd_pos sd_n, auto simp: mat_default_def default_gt_zero)
qed
end


context SN_one_mono_ordered_semiring_1
begin

(* Abbreviation for the strict matrix relation with side conditions: both
   matrices are n x n, the right matrix B is related to the zero matrix by
   mat_ge, and A and B are related by mat_gt for the bound sd. *)
abbreviation mat_s :: "'a mat  nat  nat  'a mat  bool" ("(_ m _ _ _)" [51,51,51,51] 50)
 where "A m n sd B  (A  carrier_mat n n  B  carrier_mat n n  B m 0m n n  mat_gt (≻) sd A B)"

(* The relation mat_s is strongly normalising (under assumption sd_n relating
   sd and n; relation symbol missing in this text).  The proof maps an assumed
   infinite chain of matrices to the chain of their entry sums and uses
   strong normalisation (fact SN) on the entries. *)
lemma mat_gt_SN: assumes sd_n: "sd  n" shows "SN {(m1,m2) . m1 m n sd m2}"
proof 
  fix A
  (* Assume an infinite chain A 0, A 1, ... in the relation. *)
  assume " i. (A i, A (Suc i))  {(m1,m2). m1 m n sd m2}"
  hence " i. (A i, A (Suc i))  {(m1,m2). m1 m n sd m2}" by blast
  hence A: " i. A i  carrier_mat n n" 
    and ge: " i. A (Suc i) m 0m n n" 
    and gt: " i. mat_gt (≻) sd (A i) (A (Suc i))" by auto  
  (* s i is the sum of all entries of the i-th matrix of the chain. *)
  define s where "s = (λ i. sum_mat (A i))"
  {
    fix i
    (* The sums form a strict chain, and every successor sum is related to 0. *)
    from sum_mat_mono_gt[OF sd_n A A gt[of i]]
    have gt: "s i  s (Suc i)" unfolding s_def .
    from sum_mat_mono[OF A _ ge[of i]]
    have ge: "s (Suc i)  0" unfolding s_def by auto
    note ge gt 
  }
  (* Such a chain contradicts fact SN. *)
  with SN show False by auto
qed
end

context SN_strict_mono_ordered_semiring_1
begin 

(* Strict monotonicity of left multiplication: for n x n matrices, if B and C
   are related by mat_gt, A is related to the zero matrix by mat_ge and A
   satisfies mat_mono mono sd, then A * B and A * C are related by mat_gt.
   (Assumption sd_n relates sd and n; its relation symbol is missing in this
   text, presumably "sd at most n".) *)
lemma mat_mono: assumes sd_n: "sd  n" and A: "A  carrier_mat n n" and B: "B  carrier_mat n n" and C: "C  carrier_mat n n" 
  and gt: "mat_gt (≻) sd B C" and gez: "A m 0m n n" and mmono: "mat_mono mono sd A"
  shows "mat_gt (≻) sd (A * B) (A * C)" (is "mat_gt _ _ ?AB ?AC")
proof -
  (* (i, j) is a position with a strict decrease from B to C. *)
  from mat_gtD[OF gt] obtain i j where 
    i: "i < sd" and j: "j < sd" and gt: "B $$ (i,j)  C $$ (i,j)" and BC: "B m C" by auto
  (* The non-strict part follows from monotonicity of multiplication. *)
  from mat_mult_right_mono[OF gez BC A B C] have ge: "?AB m ?AC" .
  (* mat_mono provides a row index k < sd such that the entry of A at (k, i)
     satisfies mono; the strict decrease will be shown at position (k, j). *)
  from mmono[unfolded mat_mono_def] i obtain k where k: "k < sd" and mon: "mono (A $$ (k,i))" by auto
  from mat_geD[OF gez] k i sd_n A have "A $$ (k, i)  0" by auto
  note mono = mono[OF mon gt this]
  have id: "dim_vec (col B j) = n" "dim_vec (col C j) = n" using j sd_n B C by auto
  {
    (* All summands of the (k, j) entries are related by the non-strict order. *)
    fix i
    assume "i < n"
    hence "row A k $ i * col B j $ i  row A k $ i * col C j $ i"
      by (intro times_right_mono, insert j k sd_n A B C mat_geD[OF gez] mat_geD[OF BC], auto)
  } note sge = this
  (* The summand with index i is strict (fact mono), so the sums are related
     by the strict order via sum_mono_gt. *)
  have gt: "row A k  col B j  row A k  col C j" unfolding scalar_prod_def id
    by (rule sum_mono_gt[of _ _ _ i, OF sge], insert mono k i j A B C sd_n, auto)
  show ?thesis
    by (rule mat_gtI[OF ge k j], insert k j sd_n A B C gt, auto)
qed
end

(* Entry-wise comparison of two matrices by a binary relation r:
   r (A $$ (i,j)) (B $$ (i,j)) is required at the positions i < dim_row A,
   j < dim_col A.  Note that the index bounds are taken from A only. *)
definition mat_comp_all :: "('a  'a  bool)  'a mat  'a mat  bool"
where "mat_comp_all r A B =
   ( i < dim_row A.  j < dim_col A. r (A $$ (i,j)) (B $$ (i,j)))"

(* Introduction rule for mat_comp_all: for A and B in carrier_mat nr nc it
   suffices to show r on every entry with i < nr and j < nc. *)
lemma mat_comp_allI:
  assumes "A  carrier_mat nr nc" "B  carrier_mat nr nc"
  and " i j. i < nr  j < nc  r (A $$(i,j)) (B $$ (i,j))"
  shows "mat_comp_all r A B"
  unfolding mat_comp_all_def using assms by simp

(* Destruction rule for mat_comp_all (named ...E, but stated as a forward rule):
   from mat_comp_all r A B with A and B in carrier_mat nr nc, obtain r on the
   entries at any position i < nr, j < nc. *)
lemma mat_comp_allE:
  assumes "mat_comp_all r A B"
  and "A  carrier_mat nr nc" "B  carrier_mat nr nc"
  shows " i j. i < nr  j < nc  r (A $$ (i,j)) (B $$(i,j))"
  using assms unfolding mat_comp_all_def by auto

context weak_SN_both_mono_ordered_semiring_1
begin

(* Entry-wise comparison of matrices with the locale's weak_gt relation. *)
abbreviation weak_mat_gt_arc :: "'a mat  'a mat  bool"
where "weak_mat_gt_arc  mat_comp_all weak_gt"

(* Lifts weak_gt_both_mono from scalars to matrices: for a list ABs of pairs of
   n x n matrices that are all oriented by weak_mat_gt_arc, there is a relation gt
   forming an SN_both_mono_ordered_semiring_1 (with default and arc_pos) under
   which every pair is oriented by mat_comp_all gt.
   Proof: collect the entry pairs of all matrix pairs (?pairs), keep those related
   by weak_gt (?strict), feed that list to weak_gt_both_mono to obtain gt, and
   then check every matrix pair entry by entry. *)
lemma weak_mat_gt_both_mono:
   assumes ABs: "set ABs  carrier_mat n n × carrier_mat n n"
   and orient: "(A,B)  set ABs. weak_mat_gt_arc A B"
   shows " gt. SN_both_mono_ordered_semiring_1 default gt arc_pos 
   ((A,B)  set ABs. mat_comp_all gt A B)"
proof -
  let ?pairs = "[ (fst AB $$ (i,j), snd AB $$ (i,j)) . AB <- ABs, i <- [0 ..< n], j <- [0 ..< n]]"
  let ?strict = "filter (λ (x,y). weak_gt x y) ?pairs"
  have " x y. (x,y)  set ?strict  weak_gt x y" by auto
  from weak_gt_both_mono[OF this]
  obtain gt
    where order: "SN_both_mono_ordered_semiring_1 default gt arc_pos"
    and orient2: " x y. (x, y)  set ?strict  gt x y"
    by auto
  {
    fix A B assume AB: "(A,B)  set ABs"
    hence A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
      using AB ABs by auto
    have "mat_comp_all gt A B"
    proof (rule mat_comp_allI[OF A B])
      fix i j
      assume i: "i < n" and j: "j < n"
      from mat_comp_allE[OF _ A B this] orient AB
      have weak_gt: "weak_gt (A $$(i,j)) (B $$ (i,j))" (is "weak_gt ?x ?y") by auto
      (* the entry pair occurs in ?pairs and passes the weak_gt filter *)
      have "(?x,?y)  set ?pairs" using A AB i j by force
      with weak_gt
      have gt: "(?x,?y)  set ?strict" by simp
      show "gt ?x ?y" by (rule orient2[OF gt])
    qed
  }
  hence "(A, B)set ABs. mat_comp_all gt A B" by auto
  thus ?thesis using order by auto
qed
end

(* Ordered-semiring record for n-dimensional matrices: extends ring_mat TYPE('a) n
   with geq = mat_ge, gt = entry-wise comparison by the given gt (mat_comp_all gt),
   max = mat_max, and b as the remaining extension field. *)
definition mat_both_ordered_semiring :: "nat  ('a :: ordered_semiring_1  'a  bool)  'b  ('a mat,'b) ordered_semiring_scheme" where
  "mat_both_ordered_semiring n gt b  ring_mat TYPE('a) n 
    ordered_semiring.geq = mat_ge,
    gt = mat_comp_all gt,
    max = mat_max,
     = b"

(* A matrix is arctic positive w.r.t. the predicate ap iff ap holds for its
   entry at position (0,0); no other entry is inspected. *)
definition mat_arc_posI :: "('a  bool)  'a mat  bool"
where "mat_arc_posI ap A  ap (A $$ (0,0))"

context both_mono_ordered_semiring_1
begin 

(* Entry-wise comparison of matrices with the locale's relation gt. *)
abbreviation mat_gt_arc :: "'a mat  'a mat  bool"
where "mat_gt_arc  mat_comp_all gt"

(* mat_arc_posI specialised to the locale's arc_pos predicate. *)
abbreviation mat_arc_pos :: "'a mat  bool"
where "mat_arc_pos  mat_arc_posI arc_pos"

(* If mat_ge A B holds for A and B of the same dimensions, then mat_max A B = A.
   This is mat_max_ge_0 instantiated with the three assumptions. *)
lemma mat_max_id: fixes A :: "'a mat"
  assumes ge: "mat_ge A B"
  and A: "A  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc"
  shows "mat_max A B = A"
  using mat_max_ge_0[OF A B ge] .

(* Transitivity of mat_gt_arc for matrices of equal dimensions; proved entry by
   entry with the scalar fact gt_trans. *)
lemma mat_gt_arc_trans:
  assumes A_B: "mat_gt_arc A B"
  and B_C: "mat_gt_arc B C"
  and A: "A  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc"
  and C: "C  carrier_mat nr nc"
  shows "mat_gt_arc A C"
proof (rule mat_comp_allI[OF A C])
  fix i j
  assume i: "i < nr" and j: "j < nc"
  from mat_comp_allE[OF A_B A B i j] mat_comp_allE[OF B_C B C i j]
  show "A $$ (i,j)  C $$ (i,j)" by (rule gt_trans)
qed

(* Compatibility, weak step first: mat_ge A B followed by mat_gt_arc B C yields
   mat_gt_arc A C.  Proved per entry by a calculational chain
   (also ... finally) through B $$ (i,j). *)
lemma mat_gt_arc_compat:
  assumes ge: "mat_ge A B"
  and gt: "mat_gt_arc B C"
  and A: "A  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc"
  and C: "C  carrier_mat nr nc"
  shows "mat_gt_arc A C"
proof (rule mat_comp_allI[OF A C])
  fix i j assume i: "i < nr" and j: "j < nc"
  have "A $$ (i,j)  B $$ (i,j)" using ge A i j by auto
  also have "B $$ (i,j)  C $$ (i,j)"
    using mat_comp_allE[OF gt B C i j] by auto
  finally show "A $$ (i,j)  C $$ (i,j)" by auto
qed

(* Compatibility, strict step first: mat_gt_arc A B followed by mat_ge B C yields
   mat_gt_arc A C.  Mirror image of mat_gt_arc_compat, again proved per entry by a
   calculational chain through B $$ (i,j). *)
lemma mat_gt_arc_compat2:
  assumes gt: "mat_gt_arc A B"
  and ge: "mat_ge B C"
  and A: "A  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc"
  and C: "C  carrier_mat nr nc"
  shows "mat_gt_arc A C"
proof (rule mat_comp_allI[OF A C])
  fix i j assume i: "i < nr" and j: "j < nc"
  have "A $$ (i,j)  B $$ (i,j)"
    using mat_comp_allE[OF gt] A B i j by auto
  also have "B $$ (i,j)  C $$ (i,j)"
    using ge B i j by auto
  finally show "A $$ (i,j)  C $$ (i,j)" by auto
qed

(* The entry-wise strict comparison implies the weak matrix comparison:
   mat_gt_arc A B gives mat_ge A B for matrices of equal dimensions.
   Uses mat_geI together with the scalar fact gt_imp_ge on every entry. *)
lemma mat_gt_arc_imp_mat_ge:
  assumes gt: "mat_gt_arc A B"
  and A: "A  carrier_mat nr nc"
  and B: "B  carrier_mat nr nc"
  shows "mat_ge A B"
  using subst mat_geI[OF A]
  using mat_comp_allE[OF gt A B] gt_imp_ge by auto

(* For n > 0 the record mat_both_ordered_semiring n (≻) b satisfies the
   ordered_semiring locale.
   Proof: the semiring part is obtained by interpretation from semiring_mat; the
   remaining locale axioms are unfolded to statements about mat_ge, mat_max and
   mat_gt_arc and discharged by auto with the lemmas listed as intro rules. *)
lemma (in both_mono_ordered_semiring_1) mat_both_ordered_semiring: assumes n: "n > 0" 
  shows "ordered_semiring 
    (mat_both_ordered_semiring n (≻) b :: ('a mat,'b) ordered_semiring_scheme)" 
  (is "ordered_semiring ?R")
proof -
  interpret semiring ?R unfolding mat_both_ordered_semiring_def by (rule semiring_mat)
  show ?thesis 
    apply (unfold_locales)
    unfolding ring_mat_def mat_both_ordered_semiring_def ordered_semiring_record_simps
    apply(
      auto intro: mat_max_comm mat_ge_trans
      mat_plus_left_mono mat_mult_left_mono mat_mult_right_mono mat_ge_refl
      one_mat_ge_zero mat_max_mono mat_max_ge mat_max_id
      mat_gt_arc_trans mat_gt_arc_imp_mat_ge
      mat_gt_arc_compat mat_gt_arc_compat2)
    done
qed


(* Every matrix A in carrier_mat nr nc is related by mat_gt_arc to the zero matrix
   of the same dimensions; each entry is handled by the scalar fact zero_leastI. *)
lemma mat0_leastI:
  assumes A: "A  carrier_mat nr nc"
  shows "mat_gt_arc A (0m nr nc)"
proof (rule mat_comp_allI[OF A])
  fix i j
  assume i: "i < nr" and j: "j < nc"
  thus "A $$ (i,j)  0m nr nc $$ (i,j)" by (auto simp: zero_leastI)
qed auto

(* If the zero matrix is related by mat_gt_arc to A, then A is the zero matrix.
   Proof: matrix extensionality (eq_matI), then the scalar fact zero_leastII on
   every entry; the dimension side goals are closed by the final auto. *)
lemma mat0_leastII: 
  assumes gt: "mat_gt_arc (0m nr nc) A"
  and A: "A  carrier_mat nr nc"
  shows "A = 0m nr nc"
  apply (rule eq_matI)
  unfolding index_zero_mat
  using A
proof -
  fix i j
  assume i: "i < nr" and j: "j < nc"
  show "A $$ (i,j) = 0"
    using zero_leastII mat_comp_allE[OF gt _ A] i j by auto
qed auto

(* Every matrix A in carrier_mat nr nc satisfies mat_ge against the zero matrix;
   each entry is handled by the scalar fact zero_leastIII. *)
lemma mat0_leastIII:
  assumes A: "A  carrier_mat nr nc"
  shows "mat_ge A ((0m nr nc) :: 'a mat)"
proof (rule mat_geI[OF A]; unfold index_zero_mat)
  fix i j
  assume i: "i < nr" and j: "j < nc"
  show "A $$ (i,j)  0" using zero_leastIII by simp
qed

(* The zero matrix is neutral for mat_max: mat_max (0m nr nc) A = A.
   Proof: swap the arguments with mat_max_comm, then apply mat_max_id using
   mat0_leastIII. *)
lemma mat_max_0_id: fixes A :: "'a mat"
  assumes A: "A  carrier_mat nr nc"
  shows "mat_max (0m nr nc) A = A"
  unfolding mat_max_comm[OF zero_carrier_mat A]
  by (rule mat_max_id[OF mat0_leastIII[OF A] A], simp)

(* For n > 0 the identity matrix 1m n is arctic positive: its (0,0) entry is
   rewritten with index_one_mat and the claim follows from arc_pos_one. *)
lemma mat_arc_pos_one:
  assumes n0: "n > 0"
  shows "mat_arc_posI arc_pos (1m n)"
  unfolding mat_arc_posI_def
  unfolding arc_pos_one index_one_mat(1)[OF n0 n0]
  using arc_pos_one by simp

(* For n > 0 the n x n zero matrix is not arctic positive: its (0,0) entry is
   rewritten with index_zero_mat and the claim follows from arc_pos_zero. *)
lemma mat_arc_pos_zero:
  assumes n0: "n > 0"
  shows "¬ mat_arc_posI arc_pos (0m n n)"
  unfolding mat_arc_posI_def
  unfolding index_zero_mat(1)[OF n0 n0] using arc_pos_zero by simp

(* mat_gt_arc is monotone in both arguments of matrix addition: from
   mat_gt_arc A B and mat_gt_arc C D (all in carrier_mat nr nc) conclude
   mat_gt_arc (A + C) (B + D).
   Per entry, index_add_mat rewrites the sums and the scalar fact
   plus_gt_both_mono finishes; carrier side goals are closed after qed. *)
lemma mat_gt_arc_plus_mono:
  assumes gt1: "mat_gt_arc A B"
  and gt2: "mat_gt_arc C D"
  and A: "(A::'a mat)  carrier_mat nr nc"
  and B: "(B::'a mat)  carrier_mat nr nc"
  and C: "(C::'a mat)  carrier_mat nr nc"
  and D: "(D::'a mat)  carrier_mat nr nc"
  shows "mat_gt_arc (A + C) (B + D)" (is "mat_gt_arc ?AC ?BD")
proof -
  show ?thesis
  proof (rule mat_comp_allI)
    fix i j
    assume i: "i < nr" and j: "j < nc"
    (* index bounds in the form required by index_add_mat *)
    hence ijC: "i < dim_row C" "j < dim_col C"
      and ijD: "i < dim_row D" "j < dim_col D"
      using C D by auto
    show "?AC $$ (i,j)  ?BD $$ (i,j)"
      unfolding index_add_mat(1)[OF ijC]
      unfolding index_add_mat(1)[OF ijD]
      using plus_gt_both_mono
      using mat_comp_allE[OF gt1 A B] mat_comp_allE[OF gt2 C D] i j by auto
  qed (insert A B C D, auto)
qed

(* Component-wise comparison of two vectors by a binary relation r, at the
   indices i < dim_vec v.  The index bound is taken from v only. *)
definition vec_comp_all :: "('a  'a  bool)  'a vec  'a vec  bool"
  where "vec_comp_all r v w  i < dim_vec v. r (v $ i) (w $ i)"

(* Introduction rule for vec_comp_all: show r on every component below dim_vec v. *)
lemma vec_comp_allI:
  assumes "i. i < dim_vec v  r (v $ i) (w $ i)"
  shows "vec_comp_all r v w"
  unfolding vec_comp_all_def using assms by auto

(* Destruction rule for vec_comp_all: obtain r on the component at any index
   i < dim_vec v. *)
lemma vec_comp_allE:
  "vec_comp_all r v w  i < dim_vec v  r (v $ i) (w $ i)"
  unfolding vec_comp_all_def by auto

(* The scalar product is monotone in its left argument w.r.t. gt: if u and v are
   related component-wise (vec_comp_all gt u v), then scalar_prod u w and
   scalar_prod v w are related, for u, v, w in carrier_vec n.
   Proof: induction on the length m of the summed prefix (bounded by n); the base
   case uses zero_leastI, the step combines plus_gt_both_mono with
   times_gt_left_mono.  The result is then instantiated at m = n. *)
lemma scalar_prod_left_mono:
  assumes u: "u  carrier_vec n"
  and v: "v  carrier_vec n"
  and w: "w  carrier_vec n"
  and uv: "vec_comp_all gt u v"
  shows "scalar_prod u w  scalar_prod v w"
proof -
  { fix m assume "m  n"
    hence "(i<m. (u $ i) * (w $ i))  (i<m. (v $ i) * (w $ i))"
    proof (induct m)
      case 0 show ?case using zero_leastI by simp next
      case (Suc m)
        hence uv: "u $ m  v $ m"
          using vec_comp_allE[OF uv] u by auto
        show ?case
          unfolding sum.lessThan_Suc
          apply (subst plus_gt_both_mono) 
          using times_gt_left_mono Suc times_gt_left_mono[OF uv] by auto
    qed
  }
  from this[OF order.refl]
  show ?thesis
    unfolding scalar_prod_def atLeast0LessThan
    using w by auto
qed

(* The scalar product is monotone in its right argument w.r.t. gt: if v and w are
   related component-wise (vec_comp_all gt v w), then scalar_prod u v and
   scalar_prod u w are related, for u, v, w in carrier_vec n.
   Same structure as scalar_prod_left_mono: induction on the prefix length m, with
   times_gt_right_mono supplying the strict factor in the step case. *)
lemma scalar_prod_right_mono:
  assumes u: "u  carrier_vec n"
  and v: "v  carrier_vec n"
  and w: "w  carrier_vec n"
  and vw: "vec_comp_all gt v w"
  shows "scalar_prod u v  scalar_prod u w"
proof -
  { fix m assume "m  n"
    hence "(i<m. (u $ i) * (v $ i))  (i<m. (u $ i) * (w $ i))"
    proof (induct m)
      case 0 show ?case using zero_leastI by simp next
      case (Suc m)
        hence vw: "v $ m  w $ m"
          using vec_comp_allE[OF vw] v by auto
        show ?case
          unfolding sum.lessThan_Suc
          apply (subst plus_gt_both_mono) 
          using times_gt_left_mono Suc times_gt_right_mono[OF vw] by auto
    qed
  }
  from this[OF order.refl]
  show ?thesis
    unfolding scalar_prod_def atLeast0LessThan
    using v w by auto
qed

(* mat_gt_arc is preserved by multiplication with a fixed matrix on the right:
   mat_gt_arc A B gives mat_gt_arc (A * C) (B * C) for A, B of shape nr x n and
   C of shape n x nc.
   Each product entry is a scalar product of a row and a column
   (index_mult_mat); scalar_prod_left_mono applies because the rows of A and B
   are related component-wise. *)
lemma mat_gt_arc_mult_left_mono:
  assumes gt1: "mat_gt_arc A B"
  and A: "(A::'a mat)  carrier_mat nr n"
  and B: "(B::'a mat)  carrier_mat nr n"
  and C: "(C::'a mat)  carrier_mat n nc"
  shows "mat_gt_arc (A * C) (B * C)" (is "mat_gt_arc ?AC ?BC")
proof (rule mat_comp_allI)
  fix i j assume i: "i < nr" and j: "j < nc"
  hence iA: "i < dim_row A"
    and iB: "i < dim_row B"
    and jC: "j < dim_col C"
    using A B C by auto
  show "?AC $$ (i,j)  ?BC $$ (i,j)"
    unfolding index_mult_mat(1)[OF iA jC]
    unfolding index_mult_mat(1)[OF iB jC]
  proof(rule scalar_prod_left_mono)
    show "row A i  carrier_vec n" using A by auto
    show "row B i  carrier_vec n" using B by auto
    show "col C j  carrier_vec n" using C by auto
    show rowAB: "vec_comp_all (≻) (row A i) (row B i)"
    proof (intro vec_comp_allI)
      (* this inner j is a column index of row i and shadows the outer j *)
      fix j assume j: "j < dim_vec (row A i)"
      have "A $$ (i,j)  B $$ (i,j)"
        using mat_comp_allE[OF gt1 A B i] j A by simp
      thus "row A i $ j  row B i $ j"
        using A B C i j by simp
    qed
  qed
qed (insert A B C, auto)

(* mat_gt_arc is preserved by multiplication with a fixed matrix on the left:
   mat_gt_arc B C gives mat_gt_arc (A * B) (A * C) for A of shape nr x n and
   B, C of shape n x nc.
   Each product entry is a scalar product of a row and a column
   (index_mult_mat); scalar_prod_right_mono applies because the columns of B and
   C are related component-wise.  (The label rowAB below names a fact about
   columns.) *)
lemma mat_gt_arc_mult_right_mono:
  assumes gt1: "mat_gt_arc B C"
  and A: "(A::'a mat)  carrier_mat nr n"
  and B: "(B::'a mat)  carrier_mat n nc"
  and C: "(C::'a mat)  carrier_mat n nc"
  shows "mat_gt_arc (A * B) (A * C)" (is "mat_gt_arc ?AB ?AC")
proof (rule mat_comp_allI)
  fix i j assume i: "i < nr" and j: "j < nc"
  hence iA: "i < dim_row A"
    and jB: "j < dim_col B"
    and jC: "j < dim_col C"
    using A B C by auto
  show "?AB $$ (i,j)  ?AC $$ (i,j)"
    unfolding index_mult_mat(1)[OF iA jB]
    unfolding index_mult_mat(1)[OF iA jC]
  proof(rule scalar_prod_right_mono)
    show "row A i  carrier_vec n" using A by auto
    show "col B j  carrier_vec n" using B by auto
    show "col C j  carrier_vec n" using C by auto
    show rowAB: "vec_comp_all (≻) (col B j) (col C j)"
    proof (intro vec_comp_allI)
      (* this inner i is a row index of column j and shadows the outer i *)
      fix i assume i: "i < dim_vec (col B j)"
      have "B $$ (i,j)  C $$ (i,j)"
        using mat_comp_allE[OF gt1 B C] i j B by simp
      thus "col B j $ i  col C j $ i"
        using A B C i j by simp
    qed
  qed
qed (insert A B C, auto)

(* Arctic positivity survives addition on the right: if A is arctic positive then
   so is A + B (n > 0, both matrices n x n).  The (0,0) entry of the sum is
   rewritten with index_add_mat and the scalar fact arc_pos_plus is applied. *)
lemma mat_arc_pos_plus:
  assumes n0: "n > 0" 
  and A: "A  carrier_mat n n"
  and B: "B  carrier_mat n n"
  and arc_pos: "mat_arc_pos A"
  shows "mat_arc_pos (A + B)"
  unfolding mat_arc_posI_def
  apply (subst index_add_mat(1))
  using arc_pos_plus[OF arc_pos[unfolded mat_arc_posI_def]]
  assms by auto

(* Splits the scalar product of row 0 of A and column 0 of B (n x n, n > 0) into
   its first summand A $$ (0,0) * B $$ (0,0) and the sum over the indices 1..<n.
   Follows from scalar_prod_def and sum.atLeast_Suc_lessThan. *)
lemma scalar_prod_split_head: assumes 
  "A  carrier_mat n n" "B  carrier_mat n n" "n > 0" 
  shows "row A 0  col B 0 = A $$ (0,0) * B $$ (0,0) + (i = 1..<n. A $$ (0, i) * B $$ (i, 0))"
  unfolding scalar_prod_def
  using assms sum.atLeast_Suc_lessThan by auto


(* The product of two arctic positive n x n matrices (n > 0) is arctic positive.
   Proof: the (0,0) entry of A * B is the scalar product of row 0 and column 0;
   scalar_prod_split_head isolates the head A $$ (0,0) * B $$ (0,0), which is
   arc_pos by arc_pos_mult, and arc_pos_plus absorbs the remaining sum. *)
lemma mat_arc_pos_mult:
  assumes n0: "n > 0" 
  and A: "A  carrier_mat n n"
  and B: "B  carrier_mat n n"
  and apA: "mat_arc_pos A"
  and apB: "mat_arc_pos B"
  shows "mat_arc_pos (A * B)"
  unfolding mat_arc_posI_def
  apply(subst index_mult_mat(1))
proof -
  let ?prod = "row A 0  col B 0"
  let ?head = "A $$ (0,0) * B $$ (0,0)"
  let ?rest = "i = 1..<n. A $$ (0, i) * B $$ (i, 0)"
  have ap: "arc_pos ?head"
    using apA apB
    unfolding mat_arc_posI_def
    using arc_pos_mult by auto
  have split: "?prod = ?head + ?rest"
    by (rule scalar_prod_split_head[OF A B n0])
  show "arc_pos (row A 0  col B 0)"
    unfolding split
    using ap arc_pos_plus by auto
qed (insert A B n0, auto)

(* For n > 0 the default matrix mat_default default n is arctic positive: after
   unfolding mat_default_def and evaluating the (0,0) entry with index_mat, the
   claim follows from arc_pos_default. *)
lemma mat_arc_pos_mat_default:
  assumes n0: "n > 0" shows "mat_arc_pos (mat_default default n)"
  unfolding mat_arc_posI_def
  unfolding mat_default_def
  unfolding index_mat(1)[OF n0 n0]
  using arc_pos_default by simp

(* Matrix version of the scalar fact not_all_ge: for n > 0, n x n matrices A and B
   with B arctic positive, there is an n x n matrix C that is compared with the
   zero matrix, is arctic positive, and for which mat_ge A (B * C) fails.
   Construction: not_all_ge applied to c = A $$ (0,0) and B $$ (0,0) yields a
   scalar e; C has e at position (0,0) and 0 everywhere else.  Then
   (B * C) $$ (0,0) = B $$ (0,0) * e, so the comparison already fails at (0,0).
   Two auxiliary facts that were proved but never referenced (a decomposition
   n = Suc nn and a description of column 0 of C) have been removed. *)
lemma mat_not_all_ge:
  assumes n_pos: "n > 0"
  and A: "A  carrier_mat n n"
  and B: "B  carrier_mat n n"
  and apB: "mat_arc_pos B"
  shows "C. C  carrier_mat n n  mat_ge C (0m n n)  mat_arc_pos C  ¬ mat_ge A (B * C)"
proof -
  define c where "c = A $$ (0,0)"
  from apB have "arc_pos (B $$ (0,0))" unfolding mat_arc_posI_def .
  from not_all_ge[OF this, of c] obtain e where e0: "e  0" and ae: "arc_pos e"
    and nc: "¬ c  B $$ (0,0) * e" by auto
  (* the witness: e in the top-left corner, 0 elsewhere *)
  let ?f = "λ i j. if i = 0  j = 0 then e else 0"
  let ?C = "mat n n (λ (i,j). ?f i j)"
  have C: "?C  carrier_mat n n" by auto
  have C00: "?C $$ (0,0) = e" using n_pos by auto
  show ?thesis
  proof(intro exI conjI)
    show "?C m 0m n n" 
      by (rule mat_geI[of _ n n], auto simp: ge_refl e0)
    show "mat_arc_pos ?C" 
      unfolding mat_arc_posI_def 
      unfolding C00 by (rule ae)
    let ?mult = "B * ?C"
    let ?prod = "row B 0  col ?C 0"
    let ?head = "B $$ (0,0) * ?C $$ (0,0)"
    let ?rest = "i = 1..<n. B $$ (0, i) * ?C $$ (i, 0)"

    (* evaluate the (0,0) entry of the product: only the head summand survives *)
    from n_pos B have "?mult $$ (0,0) = ?prod" by auto
    also have " = ?head + ?rest"
      by (rule scalar_prod_split_head[OF B C n_pos])
    also have "?rest = 0"
      by (rule sum.neutral, auto)
    finally have "?mult $$ (0,0) = B $$ (0,0) * e" using n_pos by simp
    with nc c_def have not_ge: "¬ A $$ (0,0)  ?mult $$ (0,0)" by simp
    show "¬ A m ?mult" 
    proof
      assume "A m ?mult"
      from mat_geD[OF this, of 0 0] A B not_ge n_pos show False by auto
    qed
  qed auto
qed

end

context SN_both_mono_ordered_semiring_1
begin

(* Strong normalization of mat_gt_arc restricted to n x n matrices whose
   right-hand side is arctic positive (n > 0).
   Proof by contradiction: an infinite chain f of matrices projects to an infinite
   chain of (0,0) entries; every successor entry is arc_pos (pos) and each step is
   related entry-wise (gt), which contradicts the locale fact SN. *)
lemma mat_gt_arc_SN:
  assumes n_pos: "n > 0"
  shows "SN {(A,B)  carrier_mat n n × carrier_mat n n. mat_arc_pos B  mat_gt_arc A B}"
  (is "SN ?rel")
proof (rule ccontr)
  assume "¬ SN ?rel"
  then obtain f A where "f (0 :: nat) = A" and steps: " i. (f i, f (Suc i))  ?rel" unfolding SN_defs by blast
  hence pos: " i. arc_pos (f (Suc i) $$ (0,0))" unfolding mat_arc_posI_def by blast
  have gt: " i. f i $$ (0,0)  f (Suc i) $$ (0,0)"
  proof
    fix i
    from steps 
    have wf1: "f i  carrier_mat n n"
      and wf2: "f (Suc i)  carrier_mat n n"
      and gt: "mat_gt_arc (f i) (f (Suc i))" by auto
    show "f i $$ (0,0)   f (Suc i) $$ (0,0)"
      using mat_comp_allE[OF gt wf1 wf2]
      using index_zero_mat n_pos by force
  qed
  from pos gt SN show False unfolding SN_defs by force
qed


end

end

Theory Ring_Hom_Matrix

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Matrix Conversions›

text ‹Essentially, the idea is to use the JNF results to estimate the growth rates of 
  matrices. Since the results in JNF are only applicable for real normed fields,
  we cannot directly use them for matrices over the integers or the rational numbers.
  To this end, we define a homomorphism which allows us to first convert all numbers
  to real numbers, and then do the analysis.›

theory Ring_Hom_Matrix
imports 
  Matrix
  Polynomial_Interpolation.Ring_Hom
begin

(* Locale for an idom_hom "hom" from a linordered_idom into a floor_ceiling type
   that additionally satisfies hom_le, which relates a bound on hom x by z to a
   bound on x by of_int applied to z.
   NOTE(review): the comparison symbols and the rounding applied to z are missing
   in this copy of the source; confirm the exact statement against the original. *)
locale ord_ring_hom = idom_hom hom for 
  hom :: "'a :: linordered_idom  'b :: floor_ceiling" +
  assumes hom_le: "hom x  z  x  of_int z"

text ‹Now a class based variant especially for homomorphisms into the reals.›
(* Type class of linordered_idoms equipped with a map real_of into the reals that
   preserves +, *, 0 and 1 and satisfies the bound-transfer property real_le
   (the class-level counterpart of hom_le in ord_ring_hom). *)
class real_embedding = linordered_idom + 
  fixes real_of :: "'a  real"
  assumes
    real_add: "real_of ((x :: 'a) + y) = real_of x + real_of y" and
    real_mult: "real_of (x * y) = real_of x * real_of y" and
    real_zero: "real_of 0 = 0" and
    real_one: "real_of 1 = 1" and
    real_le: "real_of x  z  x  of_int z"

(* Every real_embedding instance gives an ord_ring_hom via real_of; each locale
   assumption is exactly one of the class axioms. *)
interpretation real_embedding: ord_ring_hom "(real_of :: 'a :: real_embedding  real)"
  by (unfold_locales; fact real_add real_mult real_zero real_one real_le)

(* The reals embed into themselves: real_of is the identity on type real. *)
instantiation real :: real_embedding
begin
definition real_of_real :: "real  real" where
  "real_of_real x = x"

instance
  by (intro_classes, auto simp: real_of_real_def, linarith)
end

(* The integers embed into the reals: real_of on int is defined by "x" at type
   real, i.e. the int argument is used where a real is expected (this relies on
   an implicit int-to-real conversion). *)
instantiation int :: real_embedding
begin

definition real_of_int :: "int  real" where
  "real_of_int x = x"

instance
  by (intro_classes, auto simp: real_of_int_def, linarith)
end

(* Bound-transfer lemma for the rationals, used to discharge real_le in the
   rat instance below: a bound on real_of_rat x by z yields a bound on x by an
   of_int value.  The proof chains the assumption with an integer bound on z
   (order_trans) and then cancels real_of_rat via of_rat_less_eq.
   NOTE(review): comparison and rounding symbols are missing in this copy. *)
lemma real_of_rat_ineq: assumes "real_of_rat x  z"
  shows "x  of_int z"
proof -
  have "z  of_int z" by linarith
  from order_trans[OF assms this] 
  have "real_of_rat x  real_of_rat (of_int z)" by auto
  thus "x  of_int z" using of_rat_less_eq by blast
qed

(* The rationals embed into the reals via of_rat; the class axioms follow from
   of_rat_add, of_rat_mult and real_of_rat_ineq. *)
instantiation rat :: real_embedding
begin
definition real_of_rat :: "rat  real" where
  "real_of_rat x = of_rat x"

instance
  by (intro_classes, auto simp: real_of_rat_def of_rat_add of_rat_mult real_of_rat_ineq)
end

(* Converts a matrix over a real_embedding type into a real matrix by applying
   real_of to every entry (map_mat). *)
abbreviation mat_real ("mat") where "mat  map_mat (real_of :: 'a :: real_embedding  real)"

end

Theory Derivation_Bound

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Derivation Bounds›

text ‹Starting from this point onwards we apply the results on matrices to derive
  complexity bounds in \isafor. So, here begins the connection to the definitions
  and prerequisites that have originally been defined within \isafor.

This theory contains the notion of a derivation bound.›

theory Derivation_Bound
imports
  "Abstract-Rewriting.Abstract_Rewriting"
begin

(* deriv_bound r a n is defined as a negated statement about the relation power
   r ^^ Suc n and pairs (a, b): no b is related to a by Suc n steps of r, so n
   bounds the length of r-derivations starting in a (see deriv_bound_iff). *)
definition deriv_bound :: "'a rel  'a  nat  bool"
where
  "deriv_bound r a n  ¬ ( b. (a, b)  r ^^ Suc n)"

(* Introduction rule: to establish deriv_bound r a n, derive False from any b and
   m with n < m and (a, b) in r ^^ m. *)
lemma deriv_boundI [intro?]:
  "(b m. n < m  (a, b)  r ^^ m  False)  deriv_bound r a n"
  by (auto simp: deriv_bound_def) (metis lessI relpow_Suc_I)

(* Elimination rule: from deriv_bound r a n one may assume that every derivation
   (a, b) in r ^^ m with n < m leads to False.  The proof splits such an m as
   Suc n + k (less_imp_Suc_add, relpow_add) and cuts it back to Suc n steps. *)
lemma deriv_boundE:
  assumes "deriv_bound r a n"
    and "(b m. n < m  (a, b)  r ^^ m  False)  P"
  shows "P"
  using assms(1)
  by (intro assms)
     (auto simp: deriv_bound_def relpow_add relcomp.simps dest!: less_imp_Suc_add, metis relpow_E2)

(* Characterisation of deriv_bound r a n in terms of all b and m with n < m and
   the relation power r ^^ m; combines deriv_boundE and deriv_boundI. *)
lemma deriv_bound_iff:
  "deriv_bound r a n  (b m. n < m  (a, b)  r ^^ m)"
  by (auto elim: deriv_boundE intro: deriv_boundI)

(* For the empty relation every n is a derivation bound of every a. *)
lemma deriv_bound_empty [simp]:
  "deriv_bound {} a n"
  by (simp add: deriv_bound_def)

(* Monotonicity in the bound: a derivation bound m for a carries over to n, given
   the first assumption relating m and n.
   NOTE(review): the comparison symbol is missing in this copy; presumably the
   assumption says that m is at most n. *)
lemma deriv_bound_mono:
  assumes "m  n" and "deriv_bound r a m"
  shows "deriv_bound r a n"
  using assms by (auto simp: deriv_bound_iff)

(* Transfers a derivation bound along a map f: if n bounds r'-derivations from
   f a, and every r-step from a to b is simulated between f a and f b by the
   relation named in assumption "step" (built from r'), then n bounds
   r-derivations from a.
   Proof: an r-derivation of length m is mapped by relpow_image, turned into an
   r'-derivation of some length k by trancl_steps_relpow, and this contradicts
   the bound b via deriv_bound_iff. *)
lemma deriv_bound_image: 
  assumes b: "deriv_bound r' (f a) n"
    and step: " a b. (a, b)  r  (f a, f b)  r'+"
  shows "deriv_bound r a n"
proof
  fix b m
  assume "(a, b)  r ^^ m"
  from relpow_image [OF step this] have "(f a, f b)  r'+ ^^ m" .
  from trancl_steps_relpow [OF subset_refl this]
    obtain k where "k  m" and "(f a, f b)  r' ^^ k" by auto
  moreover assume "n < m"
  moreover with deriv_bound_mono [OF _ b, of "m - 1"]
    have "deriv_bound r' (f a) (m - 1)" by simp
  ultimately show False using b by (simp add: deriv_bound_iff)
qed

(* Special case of deriv_bound_image with f the identity function: a derivation
   bound for r' is also one for a relation r that is related to r' as stated in
   the first assumption. *)
lemma deriv_bound_subset:
  assumes "r  r'+"
    and b: "deriv_bound r' a n"
  shows "deriv_bound r a n"
  using assms by (intro deriv_bound_image [of _ "λx. x", OF b]) auto

(* A derivation bound for a implies SN_on r {a}.
   Proof: for an infinite r-chain f starting in a, deriv_boundE excludes a
   derivation of length Suc n from f 0 to f (Suc n), while relpow_fun_conv with
   witness f provides exactly such a derivation; the two facts contradict. *)
lemma deriv_bound_SN_on:
  assumes "deriv_bound r a n"
  shows "SN_on r {a}"
proof
  fix f
  assume steps: " i. (f i, f (Suc i))  r" and "f 0  {a}"
  with assms have "(f 0, f (Suc n))  r ^^ Suc n" by (blast elim: deriv_boundE)
  moreover have "(f 0, f (Suc n))  r ^^ Suc n"
    using steps unfolding relpow_fun_conv by (intro exI [of _ f]) auto
  ultimately show False ..
qed

(* Relates the length n of an actual derivation (a, b) in r ^^ n to a derivation
   bound m of a; proved from deriv_bound_iff and not_less. *)
lemma deriv_bound_steps:
  assumes "(a, b)  r ^^ n"
    and "deriv_bound r a m"
  shows "n  m"
  using assms by (auto iff: not_less deriv_bound_iff)
end

Theory Complexity_Carrier

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Complexity Carrier›

text ‹We define which properties a carrier of matrices must exhibit, so that it
  can be used for checking complexity proofs.›

theory Complexity_Carrier
imports
  "Abstract-Rewriting.SN_Order_Carrier"
  Ring_Hom_Matrix
  Derivation_Bound
  HOL.Real
begin

(* Combination of the classes large_ordered_semiring_1 and real_embedding, with
   no additional axioms; real, int and rat are instances. *)
class large_real_ordered_semiring_1 = large_ordered_semiring_1 + real_embedding

instance real :: large_real_ordered_semiring_1 ..
instance int :: large_real_ordered_semiring_1 ..
instance rat :: large_real_ordered_semiring_1 ..

text ‹For complexity analysis, we need a bounding function which tells us how often
  one can strictly decrease a value. To this end, $\delta$-orderings are usually applied
  when working with the reals or rational numbers.›

(* Extends one_mono_ordered_semiring_1 by a function bound :: 'a => nat with
   - bound_mono: bound is monotone w.r.t. the order on 'a,
   - bound_plus, bound_plus_of_nat: its behaviour on sums (the latter an equation
     for adding of_nat n to an element compared with 0),
   - bound_zero, bound_one: its values at 0 and 1,
   - bound: bound a is a derivation bound (deriv_bound) for a w.r.t. the relation
     built from gt and the comparison of the right component with 0. *)
locale complexity_one_mono_ordered_semiring_1 = one_mono_ordered_semiring_1 default gt 
  for gt :: "'a :: large_ordered_semiring_1  'a  bool" (infix "" 50) and default :: 'a + 
  fixes bound :: "'a  nat"
  assumes bound_mono: " a b. a  b  bound a  bound b"
   and bound_plus: " a b. bound (a + b)  bound a + bound b" 
   and bound_plus_of_nat: " a n. a  0  bound (a + of_nat n) = bound a + bound (of_nat n)" 
   and bound_zero[simp]: "bound 0 = 0"
   and bound_one: "bound 1  1"
   and bound: " a. deriv_bound {(a,b). b  0  a  b} a (bound a)"
begin


(* bound grows at most linearly on natural numbers: with the constant
   c = bound 1, bound (of_nat n) is compared with c * n for every n.
   Proof by induction on n, using bound_plus on 1 + of_nat n. *)
lemma bound_linear: " c.  n. bound (of_nat n)  c * n"
proof (rule exI[of _ "bound 1"], intro allI)
  fix n
  show "bound (of_nat n)  bound 1 * n"
  proof (induct n)
    case (Suc n)
    have "bound (of_nat (Suc n)) = bound (1 + of_nat n)" by simp
    also have "...  bound 1 + bound (of_nat n)"
      by (rule bound_plus)
    also have "...  bound 1 + bound 1 * n"
      using Suc by auto
    finally show ?case by auto
  qed simp
qed

(* Compares bound (of_nat n * v) with n * bound v.
   Proof by induction on n: of_nat (Suc n) * v is rewritten to v + of_nat n * v
   and bound_plus is applied. *)
lemma bound_of_nat_times: "bound (of_nat n * v)  n * bound v"
proof (induct n)
  case (Suc n)
  have "bound (of_nat (Suc n) * v) = bound (v + of_nat n * v)" by (simp add: field_simps)
  also have "  bound v + bound (of_nat n * v)" by (rule bound_plus)
  also have "  bound v + n * bound v" using Suc by auto
  finally show ?case by simp 
qed simp

(* Compares bound (a * of_nat n) with bound a * bound (of_nat n).
   Proof by induction on n: a * of_nat (Suc n) is rewritten to a + a * of_nat n,
   bound_plus and the induction hypothesis are applied, and the final step uses
   bound_one together with the equation bound_plus_of_nat (at 1, via
   one_ge_zero) to absorb the extra summand. *)
lemma bound_mult_of_nat: "bound (a * of_nat n)  bound a * bound (of_nat n)"
proof (induct n)
  case (Suc n)
  have "bound (a * of_nat (Suc n)) = bound (a + a * of_nat n)" by (simp add: field_simps)
  also have "...  bound a + bound (a * of_nat n)"
    by (rule bound_plus)
  also have "...  bound a + bound a * bound (of_nat n)" using Suc by auto
  also have "... = bound a * (1 + bound (of_nat n))" by (simp add: field_simps)
  also have "...  bound a * (bound (1 + of_nat n))"
  proof (rule mult_le_mono2)
    show "1 + bound(of_nat n)  bound (1 + of_nat n)" using bound_one
    using bound_plus
      unfolding bound_plus_of_nat[OF one_ge_zero] by simp
  qed
  finally show ?case by simp
qed simp

(* Compares bound (a * of_nat n ^ deg) with bound a * of_nat n ^ deg.
   Proof by induction on deg: one factor of_nat n is moved to the front and
   bound_of_nat_times is applied before using the induction hypothesis. *)
lemma bound_pow_of_nat: "bound (a * of_nat n ^ deg)  bound a * of_nat n ^ deg" 
proof (induct deg)
  case (Suc deg)
  have "bound (a * of_nat n ^ Suc deg) =  bound (of_nat n * (a * of_nat n ^ deg))"
    by (simp add: field_simps)
  also have "  n * bound (a * of_nat n ^ deg)"
    by (rule bound_of_nat_times)
  also have "  n * (bound a * of_nat n ^ deg)"
    using Suc by auto
  finally show ?case by (simp add: field_simps)
qed simp
end

end

Theory Show_Arctic

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Converting Arctic Numbers to Strings›

text ‹We just instantiate arctic numbers in the show-class.›

theory Show_Arctic
imports 
  "Abstract-Rewriting.SN_Order_Carrier"
  Show.Show_Instances
begin

(* show-class instance for arctic numbers.  Num_arc i is rendered like i,
   MinInfty as the string ''-inf''; the precedence argument of shows_prec is
   ignored.  shows_list uses showsp_list with precedence 0.
   Note: the append lemma is spelled "shows_prec_artic_append" (missing "c");
   the name is kept as is because other theories may refer to it. *)
instantiation arctic :: "show"
begin

fun shows_arctic :: "arctic  shows"
where
  "shows_arctic (Num_arc i) = shows i" |
  "shows_arctic (MinInfty) = shows ''-inf''"

definition "shows_prec (p :: nat) ai = shows_arctic ai"

lemma shows_prec_artic_append [show_law_simps]:
  "shows_prec p (a :: arctic) (r @ s) = shows_prec p a r @ s"
  by (cases a) (auto simp: shows_prec_arctic_def show_law_simps)

definition "shows_list (as :: arctic list) = showsp_list shows_prec 0 as"

instance
  by standard (simp_all add: shows_list_arctic_def show_law_simps)

end

(* show-class instance for arctic_delta over a showable base type.
   Num_arc_delta i is rendered like i, MinInfty_delta as the string ''-inf'';
   the precedence argument of shows_prec is ignored.  shows_list uses
   showsp_list with precedence 0. *)
instantiation arctic_delta :: ("show") "show"
begin

fun shows_arctic_delta :: "'a arctic_delta  shows"
where
  "shows_arctic_delta (Num_arc_delta i) = shows i" |
  "shows_arctic_delta (MinInfty_delta)  = shows ''-inf''"

definition "shows_prec (d :: nat) ari = shows_arctic_delta ari"

lemma shows_prec_arctic_delta_append [show_law_simps]:
  "shows_prec d (a :: 'a arctic_delta) (r @ s) = shows_prec d a r @ s"
  by (cases a) (auto simp: shows_prec_arctic_delta_def show_law_simps)

definition "shows_list (ps :: 'a arctic_delta list) = showsp_list shows_prec 0 ps"

instance
  by standard (simp_all add: shows_list_arctic_delta_def show_law_simps)

end

end

Theory Matrix_Complexity

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Application: Complexity of Matrix Orderings›

text ‹In this theory we provide various carriers which can be used for matrix interpretations.›

theory Matrix_Complexity 
imports
  Matrix_Comparison
  Complexity_Carrier
  Show_Arctic
begin

subsection ‹Locales for Carriers of Matrix Interpretations and Polynomial Orders›

(* Carrier for matrix interpretations: an SN_one_mono_ordered_semiring_1 with
   default element d and strict order gt, over a type that is additionally in the
   show class. *)
locale matrix_carrier = SN_one_mono_ordered_semiring_1 d gt
  for gt :: "'a :: {show,ordered_semiring_1}  'a  bool" (infix "" 50) and d :: "'a"

(* Carrier for monotone matrix interpretations with complexity bounds: a
   complexity_one_mono_ordered_semiring_1 (gt, d, bound) over a showable
   large_real_ordered_semiring_1, plus a predicate mono.  The assumption "mono"
   states that multiplication on the left by an x with mono x (and x compared
   with 0) preserves the strict order between y and z. *)
locale mono_matrix_carrier = complexity_one_mono_ordered_semiring_1 gt d bound
  for gt :: "'a :: {show,large_real_ordered_semiring_1}  'a  bool" (infix "" 50) and d :: 'a
  and bound :: "'a  nat" 
+ fixes mono :: "'a  bool"
  assumes mono: " x y z. mono x  y  z  x  0  x * y  x * z"


text ‹The weak version makes comparisons with $>$ and then synthesizes a suitable
  $\delta$-ordering by choosing the least difference in the finite set of
  comparisons.›

(* Locale for a "weak" order weak_gt with a default element and a predicate mono.
   Its single assumption weak_gt_mono: for any list xys of pairs that are all
   oriented by weak_gt, there are gt and bound forming a mono_matrix_carrier
   (with the same default and mono) such that gt orients every pair in xys. *)
locale weak_complexity_linear_poly_order_carrier = 
  fixes weak_gt :: "'a :: {large_real_ordered_semiring_1,show}  'a  bool"
   and  default :: "'a"
   and  mono :: "'a  bool"
  assumes weak_gt_mono: " x y. (x,y)  set xys  weak_gt x y 
    gt bound. mono_matrix_carrier gt default bound mono  ( x y. (x,y)  set xys  gt x y)"
begin

(* The strict matrix order mat_gt instantiated with the locale's weak_gt. *)
abbreviation weak_mat_gt :: "nat  'a mat  'a mat  bool"
where "weak_mat_gt  mat_gt weak_gt"

(* Lifts weak_gt_mono from scalars to matrices: if every pair (A, B) in the list
   ABs of n x n matrices is oriented by weak_mat_gt sd, then there are gt and
   bound forming a mono_matrix_carrier (same default and mono) such that every
   such pair is oriented by mat_gt gt sd.
   Proof: ?pairs combines every entry of a left matrix with every entry of a
   right matrix; ?strict keeps the combinations related by weak_gt and is passed
   to weak_gt_mono.  For a concrete pair, mat_gtD yields the strictly decreasing
   position (i,j), whose entry pair lies in ?strict.
   The index bounds i < n, j < n are now derived from the named assumption sd_n
   instead of a literal fact reference. *)
lemma weak_mat_gt_mono: assumes sd_n: "sd  n" and
    orient: " A B. A  carrier_mat n n  B  carrier_mat n n  (A,B)  set ABs  weak_mat_gt sd A B"
   shows " gt bound. mono_matrix_carrier gt default bound mono 
    ( A B. A  carrier_mat n n  B  carrier_mat n n  (A, B)  set ABs  mat_gt gt sd A B)"
proof -
  let ?n = "[0 ..< n]"
  let ?m1x = "[ A $$ (i,j) . A <- map fst ABs, i <- ?n, j <- ?n]"
  let ?m2y = "[ B $$ (i,j) . B <- map snd ABs, i <- ?n, j <- ?n]"
  let ?pairs = "concat (map (λ x. map (λ y. (x,y)) ?m2y) ?m1x)"
  let ?strict = "filter (λ (x,y). weak_gt x y) ?pairs"
  have " x y. (x,y)  set ?strict  weak_gt x y" by auto
  from weak_gt_mono[OF this] obtain gt bound where order: "mono_matrix_carrier gt default bound mono" 
    and orient2: " x y. (x, y)  set ?strict  gt x y" by auto
  show ?thesis
  proof (intro exI allI conjI impI, rule order)
    fix A B
    assume A: "A  carrier_mat n n" and B: "B  carrier_mat n n"
      and AB: "(A, B)  set ABs"          
    from orient[OF this] have "mat_gt weak_gt sd A B" by auto
    from mat_gtD[OF this] obtain i j where
      ge: "A m B" and ij: "i < sd" "j < sd" and wgt: "weak_gt (A $$ (i,j)) (B $$ (i,j))"
      by auto
    (* lift the bounds below sd to bounds below n *)
    from ij sd_n have ij': "i < n" "j < n" by auto
    have gt: "gt (A $$ (i,j)) (B $$ (i,j))"
      by (rule orient2, insert ij' AB wgt, force)
    show "mat_gt gt sd A B" using ij gt ge by auto
  qed
qed
end

(* Every mono_matrix_carrier is an SN_strict_mono_ordered_semiring_1: strong
   normalization follows from the derivation bound (deriv_bound_SN_on applied to
   the locale fact bound), and strict monotonicity is the locale fact mono. *)
sublocale mono_matrix_carrier  SN_strict_mono_ordered_semiring_1 d gt mono
proof
  show "SN {(x,y). y  0  x  y}" 
    unfolding SN_def
    by (intro allI deriv_bound_SN_on[OF bound])
qed (rule mono)

(* Every mono_matrix_carrier is in particular a matrix_carrier. *)
sublocale mono_matrix_carrier  matrix_carrier ..

subsection ‹The Integers as Carrier›

(* The integers with (>), default element 1, bound function nat and int_mono
   form a mono_matrix_carrier.
   The only non-trivial obligation is the derivation bound: a ?R-derivation of
   length n from x to y relates x to y + int n (induction on n), which for
   n = Suc (nat x) together with the bound on y is contradictory.  The remaining
   locale assumptions are closed by auto using int_SN.mono. *)
lemma int_complexity:
  "mono_matrix_carrier ((>) :: int  int  bool) 1 nat int_mono"
proof (unfold_locales)
  fix x
  let ?R = "{(x, y). 0  (y :: int)  y < x}" 
  show "deriv_bound ?R x (nat x)"
    unfolding deriv_bound_def
  proof
    assume "( y. (x,y)  ?R ^^ Suc (nat x))"
    then obtain y where xy: "(x,y)  ?R ^^ Suc (nat x)" ..
    from xy have y: "0  y" by auto
    (* name the length so that the induction can generalise over it *)
    obtain n where n: "n = Suc (nat x)" by auto
    from xy[unfolded n[symmetric]]
    have "x  y + int n"
    proof (induct n arbitrary: x y)
      case 0 thus ?case by auto
    next
      case (Suc n)
      from Suc(2) obtain z where xz: "(x,z)  ?R ^^ n" and zy: "(z,y)  ?R"
        by auto
      from Suc(1)[OF xz] have le: "z + int n  x" .
      from zy have le2: "y + 1  z" by simp
      with le show ?case by auto
    qed
    with y have nx: "int n  x" by simp
    from nx have x0: "x  0" by simp
    with nx n
    show False by simp
  qed      
qed (insert int_SN.mono, auto)

(* The integers with (>) are also a weak carrier: the witnesses for the
   existential in weak_gt_mono are (>) itself and nat, justified by
   int_complexity. *)
lemma int_weak_complexity:
  "weak_complexity_linear_poly_order_carrier (>) 1 int_mono"
  by (unfold_locales, intro exI[of _ "(>)"] exI[of _ nat] conjI, rule int_complexity, auto)

subsection ‹The Rational and Real Numbers as Carrier›

(* Bound function for delta-orderings: delta_bound d x multiplies x by the
   integer ceiling (1 / d), takes the ceiling of the product and converts the
   result to a natural number with nat. *)
definition delta_bound :: "'a :: floor_ceiling  'a  nat"
where
  "delta_bound d x = nat (ceiling (x * of_int (ceiling (1 / d))))"

(* For a positive delta d, the order delta_gt d with default value def, bound
   function delta_bound d and predicate delta_mono satisfies the locale
   mono_matrix_carrier.
   NOTE(review): several relation symbols (and the ceiling brackets in the
   final calculation) are missing in this copy of the file; e.g. assumption
   d1 presumably reads "d <= def" -- confirm against the upstream source.
   Proof outline: N abbreviates ceiling (1 / d), ?nat abbreviates
   nat (ceiling (x * of_int N)); delta_bound d is rewritten to ?nat (fact nn)
   and the locale obligations are then shown one after another. *)
lemma delta_complexity:
  assumes d0: "d > 0" and d1: "d  def" 
  shows "mono_matrix_carrier (delta_gt d) def (delta_bound d) delta_mono"
proof -
  from d0 have d00: "0  d" by simp
  define N where "N = ceiling (1 / d)"
  let ?N = "of_int N :: 'a"
  from d0 have "1 / d > 0" by (auto simp: field_simps)
  with ceiling_correct[of "1 / d"] have Nd: "1 / d  ?N" and N: "N > 0" unfolding N_def by auto
  let ?nat = "λ x. nat (ceiling (x * ?N))"
  let ?gt = "delta_gt d"
  have nn: "delta_bound d = ?nat" unfolding fun_eq_iff N_def by (simp add: delta_bound_def)
  from delta_interpretation[OF d0 d1]
  interpret SN_strict_mono_ordered_semiring_1 "def" ?gt delta_mono .
  show ?thesis unfolding nn
  proof(unfold_locales)
    (* the bound of 0 is 0 *)
    show "?nat 0 = 0" by auto
  next
    (* the bound function respects the order on its argument *)
    fix x y :: 'a
    assume xy: "x  y"
    show "?nat x  ?nat y" 
      by (rule nat_mono, rule ceiling_mono, insert xy N, auto simp: field_simps)
  next
    (* relation between 1 and the bound of 1, using N > 0 *)
    have "1  nat 1" by simp
    also have "...  ?nat 1"
    proof (rule nat_mono)
      have "1 = ceiling (1 :: rat)" by simp
      also have "...  ceiling (1 * ?N)" using N by simp
      finally show "1  ceiling (1 * ?N)" .
    qed
    finally show "1  ?nat 1" .
  next
    (* bound of a sum versus sum of bounds, via ceiling_add_le *)
    fix x y :: 'a
    have "ceiling ((x + y) * ?N) = ceiling (x * ?N + y * ?N)" by (simp add: field_simps)
    also have "...  ceiling (x * ?N) + ceiling (y * ?N)" by (rule ceiling_add_le)
    finally show "?nat (x + y)  ?nat x + ?nat y" by auto
  next
    (* adding a natural number n: the bound splits exactly into the two
       summands; the integer instance int_complexity provides
       bound_plus_of_nat for the intermediate step *)
    fix x :: 'a and n :: nat
    assume x: "0  x" 
    interpret mono_matrix_carrier "(>)" 1 nat int_mono by (rule int_complexity)
    have "?nat (x + of_nat n) = nat (ceiling (x * ?N + of_nat n * ?N))" 
      by (simp add: field_simps)
    also have id: "of_nat n * ?N = of_int (of_nat (n * nat N))" using N by (simp add: field_simps)
    also have "ceiling (x * ?N + of_int (of_nat (n * nat N))) = ceiling (x * ?N) + of_nat (n * nat N)" unfolding ceiling_add_of_int ..
    also have "nat (ceiling (x * ?N) + of_nat (n * nat N)) = ?nat x + nat (int (n * nat N))"
    proof (rule bound_plus_of_nat)
      have "x * ?N  0" 
        by (rule mult_nonneg_nonneg, insert x N, auto)
      thus "ceiling (x * ?N)  0" by auto
    qed 
    also have "(nat (int (n * nat N))) = n * nat N" by presburger
    also have "n * nat N = ?nat (of_nat n)" using N by (metis id ceiling_of_int nat_int)
    finally
    show "?nat (x + of_nat n) = ?nat x + ?nat (of_nat n)" .
  next
    (* multiplication by a delta_mono element preserves delta_gt d; taken
       from the interpreted semiring locale (fact mono) *)
    fix x y z :: 'a
    assume *: "delta_mono x" "delta_gt d y z" and x: "0  x"
    from mono[OF * x]
    show "delta_gt d (x * y) (x * z)" .
  next
    (* derivation bound: a ?R-chain of length Suc (?nat x) starting at x is
       assumed and refuted; the induction relates x, the end point y and
       d * of_nat n for a chain of length n *)
    fix x :: 'a
    let ?R = "{(x,y). 0  y  ?gt x y}"
    show "deriv_bound ?R x (?nat x)" unfolding deriv_bound_def
    proof
      assume "( y. (x,y)  ?R ^^ Suc (?nat x))"
      then obtain y where xy: "(x,y)  ?R ^^ Suc (?nat x)" ..
      from xy have y: "0  y" by auto
      obtain n where n: "n = Suc (?nat x)" by auto
      from xy[unfolded n[symmetric]]
      have "x  y + d * of_nat n"
      proof (induct n arbitrary: x y)
        case 0 thus ?case by auto
      next
        case (Suc n)
        from Suc(2) obtain z where xz: "(x,z)  ?R ^^ n" and zy: "(z,y)  ?R"
          by auto
        from Suc(1)[OF xz] have le: "z + d * of_nat n  x" .
        from zy[unfolded delta_gt_def] have le2: "y + d  z" by simp
        with le show ?case by (auto simp: field_simps)
      qed
      with y have nx: "d * of_nat n  x" by simp
      have "0  d * of_nat n" by (rule mult_nonneg_nonneg, insert d00, auto)
      with nx have x0: "x  0" by auto
      have xd0: "0  x / d"
        by (rule divide_nonneg_pos[OF x0 d0])
      from nx[unfolded n]      
      have "d + d * of_nat (?nat x)  x" by (simp add: field_simps)
      with d0 have less: "d * of_nat (?nat x) < x" by simp
      from Nd d0 have "1  d * ?N" by (auto simp: field_simps)
      from mult_left_mono[OF this x0]
      have "x  d * (x * ?N)" by (simp add: ac_simps)
      also have "  d * of_nat (?nat x)"
      proof (rule mult_left_mono[OF _ d00])
        (* NOTE(review): the argument of nat below presumably was a ceiling
           expression whose brackets were lost -- confirm *)
        show "x * ?N  of_nat (nat x * ?N)" using x0 ceiling_correct[of "x * ?N"] 
          by (metis int_nat_eq le_cases of_int_0_le_iff of_int_of_nat_eq order_trans)
      qed
      also have " < x" using less .
      finally show False by simp
    qed
  qed 
qed


(* For a positive default value def, the order (>) with predicate delta_mono
   is a weak complexity carrier.  Given a finite list xys of strictly
   decreasing pairs, the proof picks d as the minimum of def and of all
   differences x - y occurring in xys.  This d is positive and at most def,
   so delta_complexity applies, and every pair of xys is related by
   delta_gt d because d is below each difference. *)
lemma delta_weak_complexity_carrier:
  assumes d0: "def > 0"
  shows "weak_complexity_linear_poly_order_carrier (>) def delta_mono"
proof
  fix xys :: "('a × 'a) list"
  assume ass: "∀x y. (x, y) ∈ set xys ⟶ y < x"
  (* differences of all pairs, and the candidate set for the minimum *)
  let ?cs = "map (λ (x,y). x - y) xys"
  let ?ds = "def # ?cs"
  define d where "d = Min (set ?ds)"
  have d: "d ≤ def" and dcs: "⋀ x. x ∈ set ?cs ⟹ d ≤ x" unfolding d_def by auto
  have "d ∈ set ?ds" unfolding d_def by (rule Min_in, auto)
  hence "d = def ∨ d ∈ set ?cs" by auto
  (* d is either def or one of the (positive) differences *)
  hence d0: "d > 0"
    by (cases, insert d0 ass, auto simp: field_simps)
  show "∃gt bound. mono_matrix_carrier gt def bound delta_mono ∧ (∀x y. (x, y) ∈ set xys ⟶ gt x y)"
    by (intro exI conjI, rule delta_complexity[OF d0 d], insert dcs, force simp: delta_gt_def)
qed

subsection ‹The Arctic Numbers as Carrier›

(* weak_gt_arctic_delta with default 1 and predicate pos_arctic_delta
   satisfies weak_SN_both_mono_ordered_semiring_1.  The terminal proof ".."
   solves this with the standard introduction step, so the fact is presumably
   already available as a locale/class instance -- confirm. *)
lemma arctic_delta_weak_carrier:
  "weak_SN_both_mono_ordered_semiring_1 weak_gt_arctic_delta 1 pos_arctic_delta" ..

(* The order (>) with default 1 and predicate pos_arctic satisfies
   weak_SN_both_mono_ordered_semiring_1.  The proof first records the strong
   variant SN_both_mono_ordered_semiring_1 as fact SN and then uses it as the
   witness after unfolding the weak locale. *)
lemma arctic_weak_carrier:
  "weak_SN_both_mono_ordered_semiring_1 (>) 1 pos_arctic"
proof -
  (* strong version, obtained by the default introduction step *)
  have SN: "SN_both_mono_ordered_semiring_1 1 (>) pos_arctic" ..
  show ?thesis
    by (unfold_locales, intro conjI exI, rule SN, auto)
qed

end

Theory Matrix_Kernel

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Matrix Kernel›

text ‹We define the kernel of a matrix $A$ and prove the following properties.

\begin{itemize}
\item The kernel stays invariant when multiplying $A$ with an invertible matrix from the left.
\item The dimension of the kernel stays invariant when 
  multiplying $A$ with an invertible matrix from the right.
\item The function find-base-vectors returns a basis of the kernel if $A$ is in row-echelon form.
\item The dimension of the kernel of a block-diagonal matrix is the sum of the dimensions of
  the kernels of the blocks.
\item There is an executable algorithm which computes the dimension of the kernel of a matrix
  (which just invokes Gauss-Jordan and then counts the number of pivot elements).
\end{itemize}
›

theory Matrix_Kernel
imports 
  VS_Connect
  Missing_VectorSpace
  Determinant
begin

(* Hide the constants span and dim coming from the real vector space
   theories, so that the names span and dim introduced as abbreviations in
   locale kernel below do not clash with them. *)
hide_const real_vector.span
hide_const (open) Real_Vector_Spaces.span
hide_const real_vector.dim
hide_const (open) Real_Vector_Spaces.dim

(* The kernel of a matrix A: all vectors v whose dimension is the number of
   columns of A and which A maps to the zero vector (of dimension the number
   of rows of A). *)
definition mat_kernel :: "'a :: comm_ring_1 mat ⇒ 'a vec set" where
  "mat_kernel A = { v . v ∈ carrier_vec (dim_col A) ∧ A *⇩v v = 0⇩v (dim_row A)}"

(* Introduction rule for kernel membership: a vector of the right dimension
   that A maps to the zero vector lies in mat_kernel A. *)
lemma mat_kernelI: assumes "A ∈ carrier_mat nr nc" "v ∈ carrier_vec nc" "A *⇩v v = 0⇩v nr"
  shows "v ∈ mat_kernel A"
  using assms unfolding mat_kernel_def by auto

(* Destruction rule for kernel membership: every element of mat_kernel A has
   dimension nc and is mapped by A to the zero vector of dimension nr. *)
lemma mat_kernelD: assumes "A ∈ carrier_mat nr nc" "v ∈ mat_kernel A"
  shows "v ∈ carrier_vec nc" "A *⇩v v = 0⇩v nr"
  using assms unfolding mat_kernel_def by auto

(* Characterisation of mat_kernel A in terms of explicit dimensions nr and nc
   instead of dim_row A and dim_col A. *)
lemma mat_kernel: assumes "A ∈ carrier_mat nr nc"
  shows "mat_kernel A = {v. v ∈ carrier_vec nc ∧ A *⇩v v = 0⇩v nr}"
  unfolding mat_kernel_def using assms by auto

(* The kernel of an nr x nc matrix only contains vectors of dimension nc. *)
lemma mat_kernel_carrier:
  assumes "A ∈ carrier_mat nr nc" shows "mat_kernel A ⊆ carrier_vec nc"
  using assms mat_kernel by auto

(* Multiplying A from the left by any matrix B of fitting dimension can only
   enlarge the kernel. *)
lemma mat_kernel_mult_subset: assumes A: "A ∈ carrier_mat nr nc"
  and B: "B ∈ carrier_mat n nr"
  shows "mat_kernel A ⊆ mat_kernel (B * A)"
proof -
  from A B have BA: "B * A ∈ carrier_mat n nc" by auto
  (* compare both kernels in their explicit set form *)
  show ?thesis unfolding mat_kernel[OF BA] mat_kernel[OF A] using A B by auto
qed

(* The kernel is closed under scalar multiplication. *)
lemma mat_kernel_smult: assumes A: "A ∈ carrier_mat nr nc"
  and v: "v ∈ mat_kernel A"
  shows "a ⋅⇩v v ∈ mat_kernel A"
proof -
  from mat_kernelD[OF A v] have v: "v ∈ carrier_vec nc"
    and z: "A *⇩v v = 0⇩v nr" by auto
  (* scale both sides of A *v v = 0 by a *)
  from arg_cong[OF z, of "λ v. a ⋅⇩v v"] v
  have "a ⋅⇩v (A *⇩v v) = 0⇩v nr" by auto
  (* pull the scalar inside the matrix-vector product *)
  also have "a ⋅⇩v (A *⇩v v) = A *⇩v (a ⋅⇩v v)" using A v by auto
  finally show ?thesis using v A
    by (intro mat_kernelI, auto)
qed

(* Multiplying A from the left by a matrix B that has a left inverse C leaves
   the kernel unchanged.  One inclusion is mat_kernel_mult_subset; for the
   other, a kernel vector of B * A is multiplied by C and C * B is replaced
   by the identity matrix. *)
lemma mat_kernel_mult_eq: assumes A: "A ∈ carrier_mat nr nc"
  and B: "B ∈ carrier_mat nr nr"
  and C: "C ∈ carrier_mat nr nr"
  and inv: "C * B = 1⇩m nr"
  shows "mat_kernel (B * A) = mat_kernel A"
proof
  from B A have BA: "B * A ∈ carrier_mat nr nc" by auto
  show "mat_kernel A ⊆ mat_kernel (B * A)" by (rule mat_kernel_mult_subset[OF A B])
  {
    fix v
    assume v: "v ∈ mat_kernel (B * A)"
    from mat_kernelD[OF BA this] have v: "v ∈ carrier_vec nc" and z: "B * A *⇩v v = 0⇩v nr" by auto
    (* apply C to both sides of (B * A) *v v = 0 *)
    from arg_cong[OF z, of "λ v. C *⇩v v"]
    have "C *⇩v (B * A *⇩v v) = 0⇩v nr" using C v by auto
    (* reassociate so that C * B becomes visible *)
    also have "C *⇩v (B * A *⇩v v) = ((C * B) * A) *⇩v v"
      unfolding assoc_mult_mat_vec[symmetric, OF C BA v]
      unfolding assoc_mult_mat[OF C B A] by simp
    also have "… = A *⇩v v" unfolding inv using A v by auto
    finally have "v ∈ mat_kernel A"
      by (intro mat_kernelI[OF A v])
  }
  thus "mat_kernel (B * A) ⊆ mat_kernel A" by auto
qed

(* Locale fixing a matrix A over a field together with its dimensions:
   nr rows and nc columns.  Inside the locale the kernel of A is treated as
   a vector space. *)
locale kernel =
  fixes nr :: nat
    and nc :: nat
    and A :: "'a :: field mat"
  assumes A: "A ∈ carrier_mat nr nc"
begin

(* NC: the space of all vectors of dimension nc over 'a. *)
sublocale NC: vec_space "TYPE('a)" nc .

(* VK: the structure NC.V with its carrier replaced by the kernel of A. *)
abbreviation "VK ≡ NC.V⦇carrier := mat_kernel A⦈"

(* Ker: VK is a vector space over class_ring.  The rewrites clauses replace
   the record selectors of VK and class_ring by the ordinary operations on
   vectors and field elements in all facts exported by this sublocale.  The
   module property is obtained from NC.submodule_is_module; the remaining
   closure conditions follow from distributivity of matrix-vector
   multiplication and the explicit description of mat_kernel A. *)
sublocale Ker: vectorspace class_ring VK
  rewrites "carrier VK = mat_kernel A"
    and [simp]: "add VK = (+)"
    and [simp]: "zero VK = 0⇩v nc"
    and [simp]: "module.smult VK = (⋅⇩v)"
    and "carrier class_ring = UNIV"
    and "monoid.mult class_ring = (*)"
    and "add class_ring = (+)"
    and "one class_ring = 1"
    and "zero class_ring = 0"
    and "a_inv (class_ring :: 'a ring) = uminus"
    and "a_minus (class_ring :: 'a ring) = minus"
    and "pow (class_ring :: 'a ring) = (^)"
    and "finsum (class_ring :: 'a ring) = sum"
    and "finprod (class_ring :: 'a ring) = prod"
    and "m_inv (class_ring :: 'a ring) x = (if x = 0 then div0 else inverse x)"
  apply (intro vectorspace.intro)
  apply (rule NC.submodule_is_module)
  apply (unfold_locales)
  by (insert A mult_add_distrib_mat_vec[OF A] mult_mat_vec[OF A] mat_kernel[OF A], auto simp: class_ring_simps)

(* Short names for the vector-space notions of the kernel space Ker. *)
abbreviation "basis ≡ Ker.basis"
abbreviation "span ≡ Ker.span"
abbreviation "lincomb ≡ Ker.lincomb"
abbreviation "dim ≡ Ker.dim"
abbreviation "lin_dep ≡ Ker.lin_dep"
abbreviation "lin_indpt ≡ Ker.lin_indpt"
abbreviation "gen_set ≡ Ker.gen_set"

(* A finite sum of kernel vectors is the same whether it is computed in the
   kernel space VK or in the surrounding space NC.V.  Proof by induction on
   the index set; the insert case unfolds finsum_insert in both spaces. *)
lemma finsum_same:
  assumes "f : S → mat_kernel A"
  shows "finsum VK f S = finsum NC.V f S"
  using assms
proof (induct S rule: infinite_finite_induct)
  case (insert s S)
    hence base: "finite S" "s ∉ S"
      and f_VK: "f : S → mat_kernel A" "f s : mat_kernel A" by auto
    (* kernel vectors are in particular vectors of dimension nc *)
    hence f_NC: "f : S → carrier_vec nc" "f s : carrier_vec nc" using mat_kernel[OF A] by auto
    have IH: "finsum VK f S = finsum NC.V f S" using insert f_VK by auto
    thus ?case
      unfolding NC.M.finsum_insert[OF base f_NC]
      unfolding Ker.finsum_insert[OF base f_VK]
      by simp
qed auto

(* Linear combinations of kernel vectors coincide in the kernel space and in
   the surrounding vector space; reduces to finsum_same after unfolding both
   definitions of lincomb. *)
lemma lincomb_same:
  assumes S_kernel: "S ⊆ mat_kernel A"
  shows "lincomb a S = NC.lincomb a S"
  unfolding Ker.lincomb_def
  unfolding NC.lincomb_def
  apply(subst finsum_same)
  using S_kernel Ker.smult_closed[unfolded module_vec_simps class_ring_simps] by auto

(* The span of a set of kernel vectors is the same in the kernel space and in
   the surrounding vector space.  Both inclusions take a witnessing finite
   linear combination and transfer it with lincomb_same. *)
lemma span_same:
  assumes S_kernel: "S ⊆ mat_kernel A"
  shows "span S = NC.span S"
proof (rule;rule)
  (* kernel span is contained in the outer span *)
  fix v assume L: "v : span S" show "v : NC.span S"
  proof -
    obtain a U where know: "finite U" "U ⊆ S" "a : U → UNIV" "v = lincomb a U"
      using L unfolding Ker.span_def by auto
    hence v: "v = NC.lincomb a U" using lincomb_same S_kernel by auto
    show ?thesis
      unfolding NC.span_def by (rule,intro exI conjI;fact)
  qed
  (* outer span is contained in the kernel span *)
  next fix v assume R: "v : NC.span S" show "v : span S"
  proof -
    obtain a U where know: "finite U" "U ⊆ S" "v = NC.lincomb a U"
      using R unfolding NC.span_def by auto
    hence v: "v = lincomb a U" using lincomb_same S_kernel by auto
    show ?thesis unfolding Ker.span_def by (rule, intro exI conjI, insert v know, auto)
  qed
qed

(* Linear dependence of a set of kernel vectors is the same notion in the
   kernel space and in the surrounding vector space.  In both directions the
   witnessing non-trivial linear combination of the zero vector is
   transferred with lincomb_same. *)
lemma lindep_same:
  assumes S_kernel: "S ⊆ mat_kernel A"
  shows "Ker.lin_dep S = NC.lin_dep S"
proof
  note [simp] = module_vec_simps class_ring_simps
  (* dependence in the kernel space implies dependence outside *)
  { assume L: "Ker.lin_dep S"
    then obtain v a U
    where finU: "finite U" and US: "U ⊆ S"
      and lc: "lincomb a U = 0⇩v nc"
      and vU: "v ∈ U"
      and av0: "a v ≠ 0"
      unfolding Ker.lin_dep_def by auto
    have lc': "NC.lincomb a U = 0⇩v nc"
      using lc lincomb_same US S_kernel by auto
    show "NC.lin_dep S" unfolding NC.lin_dep_def
      by (intro exI conjI, insert finU US lc' vU av0, auto)
  }
  (* and conversely *)
  assume R: "NC.lin_dep S"
  then obtain v a U
  where finU: "finite U" and US: "U ⊆ S"
    and lc: "NC.lincomb a U = 0⇩v nc"
    and vU: "v : U"
    and av0: "a v ≠ 0"
    unfolding NC.lin_dep_def by auto
  have lc': "lincomb a U = zero VK"
    using lc lincomb_same US S_kernel by auto
  show "Ker.lin_dep S" unfolding Ker.lin_dep_def
    by (intro exI conjI,insert finU US lc' vU av0, auto)
qed

(* Component i of a linear combination of kernel vectors is the
   corresponding linear combination of the i-th components.  Transferred from
   vec_space.lincomb_index via lincomb_same. *)
lemma lincomb_index:
  assumes i: "i < nc"
    and Xk: "X ⊆ mat_kernel A"
  shows "lincomb a X $ i = sum (λx. a x * x $ i) X"
proof -
  have X: "X ⊆ carrier_vec nc" using Xk mat_kernel_def A by auto
  show ?thesis
    using vec_space.lincomb_index[OF i X]
    using lincomb_same[OF Xk] by auto
qed

end

(* Main properties of find_base_vectors for a matrix A in row echelon form:
   the returned vectors lie in the kernel of A, they form a basis of the
   kernel, and the kernel dimension is nc minus a count of rows of A, which
   in turn equals the number of pivot positions.
   NOTE(review): many relation symbols are missing in this copy of the file.
   Presumably the second claim says that the zero vector is NOT among the
   returned vectors and the counted rows are the NON-zero rows -- confirm
   against the upstream source.
   Proof outline:
   - membership in the kernel comes from non_pivot_base;
   - non_pivot_base is injective on the selected column indices (fact inj),
     which yields the cardinality claims together with pivot_positions;
   - for the basis property, every kernel vector v is written as the linear
     combination with coefficient v $ i for the base vector built from
     column i (spanning part), and a linear combination equal to the zero
     vector is shown to have all coefficients 0 (independence part). *)
lemma find_base_vectors: assumes ref: "row_echelon_form A" 
  and A: "A  carrier_mat nr nc" shows
  "set (find_base_vectors A)  mat_kernel A"
  "0v nc  set (find_base_vectors A)"
  "kernel.basis nc A (set (find_base_vectors A))"
  "card (set (find_base_vectors A)) = nc - card { i. i < nr  row A i  0v nc}"
  "length (pivot_positions A) = card { i. i < nr  row A i  0v nc}"
  "kernel.dim nc A = nc - card { i. i < nr  row A i  0v nc}"
proof -
  note non_pivot_base = non_pivot_base[OF ref A]
  let ?B = "set (find_base_vectors A)"
  let ?pp = "set (pivot_positions A)"
  from A have dim: "dim_row A = nr" "dim_col A = nc" by auto
  from ref[unfolded row_echelon_form_def] obtain p 
  where pivot: "pivot_fun A p nc" using dim by auto
  note piv = pivot_funD[OF dim(1) pivot]
  (* every returned vector is some non_pivot_base vector and hence in the kernel *)
  {
    fix v
    assume "v  ?B"
    from this[unfolded find_base_vectors_def Let_def dim]
      obtain c where c: "c < nc" "c  snd ` ?pp"
      and res: "v = non_pivot_base A (pivot_positions A) c" by auto
    from non_pivot_base[OF c, folded res] c
    have "v  mat_kernel A" "v  0v nc" 
      by (intro mat_kernelI[OF A], auto)
  }
  thus sub: "?B  mat_kernel A" and
    "0v nc  ?B" by auto
  (* distinct column indices give distinct base vectors *)
  {
    fix j j'
    assume j: "j < nc" "j  snd ` ?pp" and j': "j' < nc" "j'  snd ` ?pp" and neq: "j'  j"
    from non_pivot_base(2)[OF j] non_pivot_base(4)[OF j' j neq]
    have "non_pivot_base A (pivot_positions A) j  non_pivot_base A (pivot_positions A) j'" by auto
  }
  hence inj: "inj_on (non_pivot_base A (pivot_positions A))
     (set [j[0..<nc] . j  snd ` ?pp])" unfolding inj_on_def by auto
    note pp = pivot_positions[OF A pivot]
  have lc: "length (pivot_positions A) = card (snd ` ?pp)"
    using distinct_card[OF pp(3)] by auto
  (* cardinality claims *)
  show card: "card ?B = nc - card { i. i < nr  row A i  0v nc}"
    "length (pivot_positions A) = card { i. i < nr  row A i  0v nc}"
    unfolding find_base_vectors_def Let_def dim set_map  card_image[OF inj] pp(4)[symmetric]
    unfolding pp(1) lc
  proof -
    have "nc - card (snd ` {(i, p i) |i. i < nr  p i  nc})
      = card {0 ..< nc} - card (snd ` {(i, p i) |i. i < nr  p i  nc})" by auto
    also have " = card ({0 ..< nc} - snd ` {(i, p i) |i. i < nr  p i  nc})"
      by (rule card_Diff_subset[symmetric], insert piv(1), force+)
    also have "{0 ..< nc} - snd ` {(i, p i) |i. i < nr  p i  nc} = (set [j[0..<nc] . j  snd ` {(i, p i) |i. i < nr  p i  nc}])"
      by auto
    finally show "card (set [j[0..<nc] . j  snd ` {(i, p i) |i. i < nr  p i  nc}]) =
      nc - card (snd ` {(i, p i) |i. i < nr  p i  nc})" by simp
  qed auto
  interpret kernel nr nc A by (unfold_locales, rule A)
  (* basis property: spanning, containment, linear independence *)
  show basis: "basis ?B"
    unfolding Ker.basis_def
  proof (intro conjI)
    show "span ?B = mat_kernel A"
    proof
      show "span ?B  mat_kernel A"
        using sub by (rule Ker.span_is_subset2)
      show "mat_kernel A  Ker.span ?B"
      proof
        fix v
        assume "v  mat_kernel A" 
        from mat_kernelD[OF A this]
        have v: "v  carrier_vec nc" and Av: "A *v v = 0v nr" by auto
        let ?bi = "non_pivot_base A (pivot_positions A)"
        let ?ran = "set [j[0..<nc] . j  snd ` ?pp]"
        let ?ran' = "set [j[0..<nc] . j  snd ` ?pp]"
        have dimv: "dim_vec v = nc" using v by auto
        (* I maps a base vector back to the column index it was built from *)
        define I where "I = (λ b. SOME i. i  ?ran  ?bi i = b)"
        {
          fix j
          assume j: "j  ?ran"
          hence " i. i  ?ran  ?bi i = ?bi j" unfolding find_base_vectors_def Let_def dim by auto
          from someI_ex[OF this] have I: "I (?bi j)  ?ran" and id: "?bi (I (?bi j)) = ?bi j" unfolding I_def by blast+
          from inj_onD[OF inj id I j] have "I (?bi j) = j" .
        } note I = this        
        (* coefficient of a base vector: the entry of v at its column index *)
        define a where "a = (λ b. v $ (I b))"
        from Ker.lincomb_closed[OF sub] have diml: "dim_vec (lincomb a ?B) = nc"
          unfolding mat_kernel_def using dim lincomb_same by auto
        have "v = lincomb a ?B"
        proof (rule eq_vecI; unfold diml dimv)
          fix j
          assume j: "j < nc"
          have "Ker.lincomb a ?B $ j = (b ?B. a b * b $ j)" by (rule lincomb_index[OF j sub])
          also have " = ( i ?ran. v $ i * ?bi i $ j)"
          proof (subst sum.reindex_cong[OF inj])
            show "?B = ?bi ` ?ran"  unfolding find_base_vectors_def Let_def dim by auto
            fix i
            assume "i  ?ran"
            hence "I (?bi i) = i" by (rule I)
            hence "a (?bi i) = v $ i" unfolding a_def by simp
            thus "a (?bi i) * ?bi i $ j = v $ i * ?bi i $ j" by simp
          qed auto
          also have " = v $ j"
          proof (cases "j  ?ran")
            (* case distinction on whether j is one of the selected column indices *)
            case True
            hence nmem: "j  snd ` set (pivot_positions A)" by auto 
            note npb = non_pivot_base[OF j nmem]
            have "( i ?ran. v $ i * (?bi i) $ j) =
              v $ j * ?bi j $ j + ( i ?ran - {j}. v $ i * ?bi i $ j)"
              by (subst sum.remove[OF _ True], auto)
            also have "?bi j $ j = 1" using npb by simp
            also have "( i  ?ran - {j}. v $ i * ?bi i $ j) = 0"
              using insert non_pivot_base(4)[OF _ _ j nmem] by (intro sum.neutral, auto)
            finally show ?thesis by simp
          next
            case False
            with j have jpp: "j  snd ` ?pp" by auto
            with j pp obtain i where i: "i < nr" and ji: "j = p i" and pi: "p i < nc" by auto
            (* use row i of the equation A *v v = 0 to express v $ j *)
            from arg_cong[OF Av, of "λ u. u $ i"] i A
            have "v $ j = v $ j - row A i  v" by auto
            also have "row A i  v = ( j = 0 ..< nc. A $$ (i,j) * v $ j)" unfolding scalar_prod_def using v A i by auto
            also have " = ( j  ?ran. A $$ (i,j) * v $ j) +  ( j  ?ran'. A $$ (i,j) * v $ j)"
              by (subst sum.union_disjoint[symmetric], auto intro: sum.cong)
            also have "( j  ?ran'. A $$ (i,j) * v $ j) =
              A $$ (i,p i) * v $ j + ( j  ?ran' - {p i}. A $$ (i,j) * v $ j)"
              using jpp by (subst sum.remove, auto simp: ji i pi)
            also have "A $$ (i, p i) = 1" using piv(4)[OF i] pi ji by auto
            also have "( j  ?ran' - {p i}. A $$ (i,j) * v $ j) = 0"
            proof (rule sum.neutral, intro ballI)
              fix j'
              assume "j'  ?ran' - {p i}"
              then obtain i' where i': "i' < nr" and j': "j' = p i'" and pi': "p i'  nc" and neq: "p i'  p i"
                unfolding pp by auto
              from pi' piv[OF i'] have pi': "p i' < nc" by auto
              from pp pi' neq j i' i have "i  i'" by auto
              from piv(5)[OF i' pi' i this]
              show "A $$ (i,j') * v $ j' = 0" unfolding j' by simp
            qed
            also have "( j  ?ran. A $$ (i,j) * v $ j) = - ( j  ?ran. v $ j * - A $$ (i,j))" 
              unfolding sum_negf[symmetric] by (rule sum.cong, auto)
            finally have vj: "v $ j = ( j  ?ran. v $ j * - A $$ (i,j))" by simp
            show ?thesis unfolding vj j
            proof (rule sum.cong[OF refl])
              fix j'
              assume j': "j'  ?ran"
              from jpp j' have jj': "j  j'" by auto
              let ?map = "map prod.swap (pivot_positions A)"
              from ji i j have "(i,j)  set (pivot_positions A)" unfolding pp by auto
              hence mem: "(j,i)  set ?map" by auto
              from pp have "distinct (map fst ?map)" unfolding map_map o_def prod.swap_def fst_conv by auto
              from map_of_is_SomeI[OF this mem] have "map_of ?map j = Some i" by auto
              hence "?bi j' $ j = - A $$ (i, j')" 
                unfolding non_pivot_base_def Let_def dim using j jj' by auto
              thus "v $ j' * ?bi j' $ j = v $ j' * - A $$ (i,j')" by simp
            qed
          qed
          finally show "v $ j = lincomb a ?B $ j" ..
        qed auto
        thus "v  span ?B" unfolding Ker.span_def by auto
      qed
    qed
    show "?B  mat_kernel A" by (rule sub)
    (* independence: looking at component j of a vanishing linear combination
       isolates the coefficient of the base vector built from column j *)
    {
      fix a v
      assume lc: "lincomb a ?B = 0v nc" and vB: "v  ?B"
      from vB[unfolded find_base_vectors_def Let_def dim]
        obtain j where j: "j < nc" "j  snd ` ?pp" and v: "v = non_pivot_base A (pivot_positions A) j"
        by auto         
      from arg_cong[OF lc, of "λ v. v $ j"] j
      have "0 = lincomb a ?B $ j" by auto
      also have " = (v?B. a v * v $ j)" 
        by (subst lincomb_index[OF j(1) sub], simp)
      also have " = a v * v $ j + (w?B - {v}. a w * w $ j)"
        by (subst sum.remove[OF _ vB], auto)
      also have "a v * v $ j = a v" using non_pivot_base[OF j, folded v] by simp
      also have "(w?B - {v}. a w * w $ j) = 0"
      proof (rule sum.neutral, intro ballI)
        fix w
        assume wB: "w  ?B - {v}"
        from this[unfolded find_base_vectors_def Let_def dim]
        obtain j' where j': "j' < nc" "j'  snd ` ?pp" and w: "w = non_pivot_base A (pivot_positions A) j'"
          by auto    
        with wB v have "j'  j" by auto
        from non_pivot_base(4)[OF j' j this]
        show "a w * w $ j = 0" unfolding w by simp
      qed
      finally have "a v = 0" by simp
    }
    thus "¬ lin_dep ?B"
      by (intro Ker.finite_lin_indpt2[OF finite_set sub], auto simp: class_field_def)
  qed
  (* dimension = cardinality of the basis just constructed *)
  show "dim = nc - card { i. i < nr  row A i  0v nc}"
    using Ker.dim_basis[OF finite_set basis] card by simp
qed


(* Dimension of the kernel of A as a global constant, with the number of
   columns taken from A itself.  The defining equation is removed from the
   code generator setup ([code del]); an executable equation is provided by
   kernel_dim_code. *)
definition kernel_dim :: "'a :: field mat ⇒ nat" where
  [code del]: "kernel_dim A = kernel.dim (dim_col A) A"

(* Inside locale kernel, the global constant kernel_dim A coincides with the
   locale's dim; declared as a simp rule. *)
lemma (in kernel) kernel_dim [simp]: "kernel_dim A = dim" unfolding kernel_dim_def
  using A by simp

(* Code equation for kernel_dim: the kernel dimension of A is the number of
   columns of A minus the number of pivot positions of gauss_jordan_single A.
   gauss_jordan_single A equals P * A for a matrix P with a left inverse Q,
   so both matrices have the same kernel (mat_kernel_mult_eq); the result is
   in row echelon form, so find_base_vectors gives its kernel dimension. *)
lemma kernel_dim_code[code]:
  "kernel_dim A = dim_col A - length (pivot_positions (gauss_jordan_single A))"
proof -
  define nr where "nr = dim_row A"
  define nc where "nc = dim_col A"
  let ?B = "gauss_jordan_single A"
  have A: "A ∈ carrier_mat nr nc" unfolding nr_def nc_def by auto
  from gauss_jordan_single[OF A refl]
    obtain P Q where AB: "?B = P * A" and QP: "Q * P = 1⇩m nr" and
    P: "P ∈ carrier_mat nr nr" and Q: "Q ∈ carrier_mat nr nr" and B: "?B ∈ carrier_mat nr nc"
    and row: "row_echelon_form ?B" by auto
  interpret K: kernel nr nc ?B
    by (unfold_locales, rule B)
  (* same kernel, hence same dimension *)
  from mat_kernel_mult_eq[OF A P Q QP, folded AB]
  have "kernel_dim A = K.dim" unfolding kernel_dim_def using A by simp
  also have "… = nc - length (pivot_positions ?B)" using find_base_vectors[OF row B] by auto
  also have "… = dim_col A - length (pivot_positions ?B)"
    unfolding nc_def by simp
  finally show ?thesis .
qed


(* The kernel of the n x n identity matrix has dimension 0 and the empty set
   as basis.  The identity matrix is in row echelon form with pivot function
   id and has n non-zero rows, so find_base_vectors applies. *)
lemma kernel_one_mat: fixes A :: "'a :: field mat" and n :: nat
  defines A: "A ≡ 1⇩m n"
  shows
    "kernel.dim n A = 0"
    "kernel.basis n A {}"
proof -
  have Ac: "A ∈ carrier_mat n n" unfolding A by auto
  have "pivot_fun A id n"
    unfolding A by (rule pivot_funI, auto)
  hence row: "row_echelon_form A" unfolding row_echelon_form_def A by auto
  (* every row of the identity matrix is non-zero *)
  have "{i. i < n ∧ row A i ≠ 0⇩v n} = {0 ..< n}" unfolding A by auto
  hence id: "card {i. i < n ∧ row A i ≠ 0⇩v n} = n" by auto
  interpret kernel n n A by (unfold_locales, rule Ac)
  from find_base_vectors[OF row Ac, unfolded id]
  show "dim = 0" "basis {}" by auto
qed

(* An upper triangular square matrix without 0 on its diagonal has a trivial
   kernel: dimension 0 and the empty set as basis.  The determinant is the
   product of the diagonal entries and hence non-zero, so A has a two-sided
   inverse B; therefore the kernel of A equals the kernel of the identity
   matrix and kernel_one_mat applies. *)
lemma kernel_upper_triangular: assumes A: "A ∈ carrier_mat n n"
  and ut: "upper_triangular A" and 0: "0 ∉ set (diag_mat A)"
  shows "kernel.dim n A = 0" "kernel.basis n A {}"
proof -
  define ma where "ma = diag_mat A"
  from det_upper_triangular[OF ut A] have "det A = prod_list (diag_mat A)" .
  (* a product of non-zero field elements is non-zero; induction over the list *)
  also have "… ≠ 0" using 0 unfolding ma_def[symmetric]
    by (induct ma, auto)
  finally have "det A ≠ 0" .
  from det_non_zero_imp_unit[OF A this, unfolded Units_def, of "()"]
    obtain B where B: "B ∈ carrier_mat n n" and BA: "B * A = 1⇩m n" and AB: "A * B = 1⇩m n"
    by (auto simp: ring_mat_def)
  (* kernel of B * A equals kernel of A, and B * A is the identity matrix *)
  from mat_kernel_mult_eq[OF A B A AB, unfolded BA]
  have id: "mat_kernel A = mat_kernel (1⇩m n)" ..
  show "kernel.dim n A = 0" "kernel.basis n A {}"
    unfolding id by (rule kernel_one_mat)+
qed

(* Every matrix kernel has a finite basis.  The basis is obtained for the
   row echelon form C of A via find_base_vectors; since C = P * A with P
   having a left inverse, C and A have the same kernel. *)
lemma kernel_basis_exists: assumes A: "A ∈ carrier_mat nr nc"
  shows "∃ B. finite B ∧ kernel.basis nc A B"
proof -
  obtain C where gj: "gauss_jordan_single A = C" by auto
  from gauss_jordan_single[OF A gj]
  obtain P Q where CPA: "C = P * A" and QP: "Q * P = 1⇩m nr"
    and P: "P ∈ carrier_mat nr nr" and Q: "Q ∈ carrier_mat nr nr"
    and C: "C ∈ carrier_mat nr nc" and row: "row_echelon_form C"
    by auto
  from find_base_vectors[OF row C] have "∃ B. finite B ∧ kernel.basis nc C B" by blast
  (* replace the kernel of C by the kernel of A in the statement above *)
  also have "mat_kernel C = mat_kernel A" unfolding CPA
    by (rule mat_kernel_mult_eq[OF A P Q QP])
  finally show ?thesis .
qed


(* Transfers a generating set of the kernel of A * B to the kernel of A,
   where B is a square matrix with B * C equal to the identity matrix:
   the image of gen under multiplication by B generates the kernel of A, is
   contained in it, and has the same cardinality as gen.
   NOTE(review): relation symbols (membership, subset, ...) are missing in
   this copy of the file; e.g. one_dir and other_dir below read identically
   here but presumably state the two opposite inclusions -- confirm against
   the upstream source.
   Proof outline: containment is direct; equal cardinality follows from
   injectivity of multiplication by B on gen (cancel with C); for generation,
   a kernel vector v of A is mapped to C times v, which lies in the kernel of
   A * B and is therefore a linear combination of some finite gen'; applying
   B to that combination and comparing components gives v as a linear
   combination of the image of gen'. *)
lemma mat_kernel_mult_right_gen_set: assumes A: "A  carrier_mat nr nc"
  and B: "B  carrier_mat nc nc"
  and C: "C  carrier_mat nc nc"
  and inv: "B * C = 1m nc"
  and gen_set: "kernel.gen_set nc (A * B) gen" and gen: "gen  mat_kernel (A * B)"
  shows "kernel.gen_set nc A (((*v) B) ` gen)" "(*v) B ` gen  mat_kernel A" "card (((*v) B) ` gen) = card gen"
proof -
  let ?AB = "A * B"
  let ?gen = "((*v) B) ` gen"
  from A B have AB: "A * B  carrier_mat nr nc" by auto
  from B have dimB: "dim_row B = nc" by auto
  from inv B C have CB: "C * B = 1m nc" by (metis mat_mult_left_right_inverse)
  interpret AB: kernel nr nc ?AB 
    by (unfold_locales, rule AB)
  interpret A: kernel nr nc A
    by (unfold_locales, rule A)
  (* images of generators lie in the kernel of A *)
  {
    fix w
    assume "w  ?gen"
    then obtain v where w: "w = B *v v" and v: "v  gen" by auto
    from v have "v  mat_kernel ?AB" using gen by auto
    hence v: "v  carrier_vec nc" and 0: "?AB *v v = 0v nr" unfolding mat_kernel[OF AB] by auto
    have "?AB *v v = A *v w" unfolding w using v A B by simp
    with 0 have 0: "A *v w = 0v nr" by auto
    from w B v have w: "w  carrier_vec nc" by auto
    from 0 w have "w  mat_kernel A" unfolding mat_kernel[OF A] by auto
  } 
  thus genn: "?gen  mat_kernel A" by auto
  hence one_dir: "A.span ?gen  mat_kernel A" by fastforce
  (* multiplication by B is injective on gen: cancel with C *)
  {
    fix v v'
    assume v: "v  gen" and v': "v'  gen" and id: "B *v v = B *v v'"
    from v v' have v: "v  carrier_vec nc" and v': "v'  carrier_vec nc" 
      using gen unfolding mat_kernel[OF AB] by auto
    from arg_cong[OF id, of "λ v. C *v v"]
    have "v = v'" using v v'
      unfolding assoc_mult_mat_vec[symmetric, OF C B v] 
        assoc_mult_mat_vec[symmetric, OF C B v'] CB
      by auto
  } note inj = this
  hence inj_gen: "inj_on ((*v) B) gen" unfolding inj_on_def by auto
  show "card ?gen = card gen" using inj_gen by (rule card_image)
  (* every kernel vector of A is in the span of the image of gen *)
  {
    fix v
    let ?Cv = "C *v v"
    assume "v  mat_kernel A"
    from mat_kernelD[OF A this] have v: "v  carrier_vec nc" and 0: "A *v v = 0v nr" by auto
    have "?AB *v ?Cv = (A * (B * C)) *v v" using A B C v 
      by (subst assoc_mult_mat_vec[symmetric, OF AB C v], subst assoc_mult_mat[OF A B C], simp)
    also have " = 0v nr" unfolding inv using 0 A v by simp
    finally have 0: "?AB *v ?Cv = 0v nr" and Cv: "?Cv  carrier_vec nc" using C v by auto
    hence "?Cv  mat_kernel ?AB" unfolding mat_kernel[OF AB] by auto
    with gen_set have "?Cv  AB.span gen" by auto
    from this[unfolded AB.Ker.span_def] obtain a gen' where 
      Cv: "?Cv = AB.lincomb a gen'" and sub: "gen'  gen" and fin: "finite gen'" by auto
    let ?gen' = "((*v) B) ` gen'"
    from sub gen have gen': "gen'  mat_kernel ?AB" by auto
    have lin1: "AB.lincomb a gen'  carrier_vec nc"
      using AB.Ker.lincomb_closed[OF gen', of a]
      unfolding mat_kernel[OF AB] by (auto simp: class_field_def)
    hence dim1: "dim_vec (AB.lincomb a gen') = nc" by auto
    hence dim1b: "dim_vec (B *v (AB.Ker.lincomb a gen')) = nc" using B by auto
    from genn sub have genn': "?gen'  mat_kernel A" by auto
    from gen sub have gen'nc: "gen'  carrier_vec nc" unfolding mat_kernel[OF AB] by auto
    (* coefficient function on the image: pull back along C *)
    define a' where "a' = (λ b. a (C *v b))"
    from A.Ker.lincomb_closed[OF genn']
    have lin2: "A.Ker.lincomb a' ?gen'  carrier_vec nc"
      unfolding mat_kernel[OF A] by (auto simp: class_field_def)
    hence dim2: "dim_vec (A.Ker.lincomb a' ?gen') = nc" by auto
    have "v = B *v ?Cv" 
      by (unfold assoc_mult_mat_vec[symmetric, OF B C v] inv, insert v, simp)
    hence "v = B *v AB.Ker.lincomb a gen'" unfolding Cv by simp
    also have " = A.Ker.lincomb a' ?gen'"
    proof (rule eq_vecI; unfold dim1 dim1b dim2)
      (* componentwise comparison of both vectors *)
      fix i
      assume i: "i < nc"
      with dimB have ii: "i < dim_row B" by auto
      from sub inj have inj: "inj_on ((*v) B) gen'" unfolding inj_on_def by auto
      {
        fix v
        assume "v  gen'"
        with gen'nc have v: "v  carrier_vec nc" by auto
        hence "a' (B *v v) = a v" unfolding a'_def assoc_mult_mat_vec[symmetric, OF C B v] CB by auto
      } note a' = this
      have "A.Ker.lincomb a' ?gen' $ i = (v(*v) B ` gen'. a' v * v $ i)"
        unfolding A.lincomb_index[OF i genn']  by simp
      also have " = (vgen'. a v * ((B *v v) $ i))"
        by (rule sum.reindex_cong[OF inj refl], auto simp: a')
      also have " = (vgen'. (j = 0..< nc. a v * row B i $ j * v $ j))"
        unfolding mult_mat_vec_def dimB scalar_prod_def index_vec[OF i]
        by (rule sum.cong, insert gen'nc, auto simp: sum_distrib_left ac_simps)
      also have " = (j = 0 ..< nc. (v  gen'. a v * row B i $ j * v $ j))"
        by (rule sum.swap)
      also have " = (j = 0..<nc. row B i $ j * (vgen'. a v * v $ j))"
        by (rule sum.cong, auto simp: sum_distrib_left ac_simps)
      also have " = (B *v AB.Ker.lincomb a gen') $ i"
        unfolding index_mult_mat_vec[OF ii]
        unfolding scalar_prod_def dim1
        by (rule sum.cong[OF refl], subst AB.lincomb_index[OF _ gen'], auto)
      finally show "(B *v AB.Ker.lincomb a gen') $ i = A.Ker.lincomb a' ?gen' $ i" ..
    qed auto
    finally have "v  A.Ker.span ?gen" using sub fin
      unfolding A.Ker.span_def by (auto simp: class_field_def intro!: exI[of _ a'] exI[of _ ?gen'])
  }
  hence other_dir: "A.Ker.span ?gen  mat_kernel A" by fastforce
  from one_dir other_dir show "kernel.gen_set nc A (((*v) B) ` gen)" by auto
qed

(* Transfers a finite basis of the kernel of A * B to a basis of the kernel
   of A, where B is a square matrix with B * C equal to the identity matrix:
   the image of the basis under multiplication by B is a basis of the kernel
   of A and has the same cardinality.
   Proof outline: the image is a generating set of the same cardinality
   (mat_kernel_mult_right_gen_set).  Conversely, mapping some basis of the
   kernel of A with C gives a generating set of the kernel of A * B, which
   bounds card gen by the dimension of the kernel of A.  A generating set
   whose cardinality does not exceed the dimension is a basis. *)
lemma mat_kernel_mult_right_basis: assumes A: "A ∈ carrier_mat nr nc"
  and B: "B ∈ carrier_mat nc nc"
  and C: "C ∈ carrier_mat nc nc"
  and inv: "B * C = 1⇩m nc"
  and fin: "finite gen"
  and basis: "kernel.basis nc (A * B) gen"
  shows "kernel.basis nc A (((*⇩v) B) ` gen)"
  "card (((*⇩v) B) ` gen) = card gen"
proof -
  let ?AB = "A * B"
  let ?gen = "((*⇩v) B) ` gen"
  from A B have AB: "?AB ∈ carrier_mat nr nc" by auto
  from B have dimB: "dim_row B = nc" by auto
  from inv B C have CB: "C * B = 1⇩m nc" by (metis mat_mult_left_right_inverse)
  interpret AB: kernel nr nc ?AB
    by (unfold_locales, rule AB)
  interpret A: kernel nr nc A
    by (unfold_locales, rule A)
  from basis[unfolded AB.Ker.basis_def] have gen_set: "AB.gen_set gen" and genAB: "gen ⊆ mat_kernel ?AB" by auto
  (* the image of gen generates the kernel of A and has the same size *)
  from mat_kernel_mult_right_gen_set[OF A B C inv gen_set genAB]
  have gen: "A.gen_set ?gen" and sub: "?gen ⊆ mat_kernel A" and card: "card ?gen = card gen" .
  from card show "card ?gen = card gen" .
  from fin have fing: "finite ?gen" by auto
  from gen have gen: "A.Ker.span ?gen = mat_kernel A" by auto
  have ABC: "A * B * C = A" using A B C inv by simp
  (* go the other way: map a basis of the kernel of A with C *)
  from kernel_basis_exists[OF A] obtain bas where finb: "finite bas" and bas: "A.basis bas" by auto
  from bas have bas': "A.gen_set bas" "bas ⊆ mat_kernel A" unfolding A.Ker.basis_def by auto
  let ?bas = "(*⇩v) C ` bas"
  from mat_kernel_mult_right_gen_set[OF AB C B CB, unfolded ABC, OF bas']
  have bas': "?bas ⊆ mat_kernel ?AB" "AB.Ker.span ?bas = mat_kernel ?AB" "card ?bas = card bas" by auto
  from finb bas have cardb: "A.dim = card bas" by (rule A.Ker.dim_basis)
  from fin basis have cardg: "AB.dim = card gen" by (rule AB.Ker.dim_basis)
  (* a generating set has at least dim many elements *)
  from AB.Ker.gen_ge_dim[OF _ bas'(1-2)] finb bas'(3) cardb cardg
  have ineq1: "card gen ≤ A.dim" by auto
  from A.Ker.dim_gen_is_basis[OF fing sub gen, unfolded card, OF this]
  show "A.basis ?gen" .
qed
  
  
(* Multiplying A from the right by a square matrix B with B * C equal to the
   identity matrix does not change the dimension of the kernel.  A finite
   basis of the kernel of A * B is mapped by B to a basis of the kernel of A
   with the same cardinality (mat_kernel_mult_right_basis). *)
lemma mat_kernel_dim_mult_eq_right: assumes A: "A ∈ carrier_mat nr nc"
  and B: "B ∈ carrier_mat nc nc"
  and C: "C ∈ carrier_mat nc nc"
  and BC: "B * C = 1⇩m nc"
  shows "kernel.dim nc (A * B) = kernel.dim nc A"
proof -
  let ?AB = "A * B"
  from A B have AB: "?AB ∈ carrier_mat nr nc" by auto
  interpret AB: kernel nr nc ?AB
    by (unfold_locales, rule AB)
  interpret A: kernel nr nc A
    by (unfold_locales, rule A)
  from kernel_basis_exists[OF AB] obtain bas where finb: "finite bas" and bas: "AB.basis bas" by auto
  let ?bas = "((*⇩v) B) ` bas"
  from mat_kernel_mult_right_basis[OF A B C BC finb bas] finb
  have bas': "A.basis ?bas" and finb': "finite ?bas" and card: "card ?bas = card bas" by auto
  (* both dimensions are the cardinalities of the respective bases *)
  show "AB.dim = A.dim" unfolding A.Ker.dim_basis[OF finb' bas'] AB.Ker.dim_basis[OF finb bas] card ..
qed


(* Locale vardim fixes only the field type 'a (through the type token f_ty).
   Inside it, the abbreviations below take the dimension as an explicit argument,
   so vectors of different dimensions can be treated in one context. *)
locale vardim =
  fixes f_ty :: "'a :: field itself"
begin

(* M k: the vector module module_vec over 'a for dimension parameter k. *)
abbreviation "M == λk. module_vec TYPE('a) k"

(* Dimension-indexed shorthands for span, linear combination and linear dependence in M k. *)
abbreviation "span == λk. LinearCombinations.module.span class_ring (M k)"
abbreviation "lincomb == λk. module.lincomb (M k)"
abbreviation "lin_dep == λk. module.lin_dep class_ring (M k)"
(* padr m v appends a zero vector of length m to v;
   unpadr m v keeps the first (dim_vec v - m) entries of v. *)
abbreviation "padr m v == v @v 0v m"
definition "unpadr m v == vec (dim_vec v - m) (λi. v $ i)"
(* padl m v prepends a zero vector of length m to v;
   unpadl m v has (dim_vec v - m) entries and reads v at index m+i, i.e. it skips the first m entries. *)
abbreviation "padl m v == 0v m @v v"
definition "unpadl m v == vec (dim_vec v - m) (λi. v $ (m+i))"

(* Unpadding undoes padding, for the right and the left variant (both are simp rules). *)
lemma unpadr_padr[simp]: "unpadr m (padr m v) = v" unfolding unpadr_def by auto
lemma unpadl_padl[simp]: "unpadl m (padl m v) = v" unfolding unpadl_def by auto

(* Conversely, padding undoes unpadding for any v taken from the image of the
   respective padding function (U is an arbitrary set of vectors). *)
lemma padr_unpadr[simp]: "v : padr m ` U  padr m (unpadr m v) = v" by auto
lemma padl_unpadl[simp]: "v : padl m ` U  padl m (unpadl m v) = v" by auto

(* Right-padding by m maps a set of n-dimensional vectors into carrier_vec (n + m).
   Stated as an explicit lemma because it is somehow not automatically proven. *)
lemma padr_image:
  assumes "U  carrier_vec n" shows "padr m ` U  carrier_vec (n + m)"
proof(rule subsetI)
  fix v assume "v : padr m ` U"
  (* Pick a preimage u in U of the padded vector v. *)
  then obtain u where "u : U" and vmu: "v = padr m u" by auto
  hence "u : carrier_vec n" using assms by auto
  (* The carrier of an appended vector follows from the carriers of its two parts. *)
  thus "v : carrier_vec (n + m)"
    unfolding vmu
    using zero_carrier_vec[of m] append_carrier_vec by metis
qed
(* Left-padding by m maps a set of n-dimensional vectors into carrier_vec (m + n).
   Note the order of the summands: the zero block of length m comes first. *)
lemma padl_image:
  assumes "U  carrier_vec n" shows "padl m ` U  carrier_vec (m + n)"
proof(rule subsetI)
  fix v assume "v : padl m ` U"
  (* Pick a preimage u in U of the padded vector v. *)
  then obtain u where "u : U" and vmu: "v = padl m u" by auto
  hence "u : carrier_vec n" using assms by auto
  (* The carrier of an appended vector follows from the carriers of its two parts. *)
  thus "v : carrier_vec (m + n)"
    unfolding vmu
    using zero_carrier_vec[of m] append_carrier_vec by metis
qed

(* Right-padding by m is injective on the n-dimensional vectors; follows from append_vec_eq. *)
lemma padr_inj:
  shows "inj_on (padr m) (carrier_vec n :: 'a vec set)"
  apply(intro inj_onI) using append_vec_eq by auto

(* Left-padding by m is injective on the n-dimensional vectors.  Here append_vec_eq is
   instantiated with the two zero prefixes, which both lie in the same carrier. *)
lemma padl_inj:
  shows "inj_on (padl m) (carrier_vec n :: 'a vec set)"
  apply(intro inj_onI)
  using append_vec_eq[OF zero_carrier_vec zero_carrier_vec] by auto

(* Padding commutes with finite linear combinations.  For a finite set U of
   n-dimensional vectors, padding lincomb n a U by m zeros yields the linear
   combination, taken in dimension n+m, of the padded vectors with coefficient
   function a o unpad m.  The local predicate goal abstracts over the pad/unpad
   pair; ?R is the right-padding instance and ?L the left-padding one.  Both are
   established together by induction over the finite set U. *)
lemma lincomb_pad:
  fixes m n a
  assumes U: "(U :: 'a vec set)  carrier_vec n"
      and finU: "finite U"
  defines "goal pad unpad W == pad m (lincomb n a W) = lincomb (n+m) (a o unpad m) (pad m ` W)"
  shows "goal padr unpadr U" (is ?R) and "goal padl unpadl U" (is "?L")
proof -
  (* Vector space structure in the source dimension n and the target dimension n+m. *)
  interpret N: vectorspace class_ring "M n" using vec_vs.
  interpret NM: vectorspace class_ring "M (n+m)" using vec_vs.
  note [simp] = module_vec_simps class_ring_simps
  have "?R  ?L" using finU U
  proof (induct set:finite)
    (* Base case: unfold the goal and the definition of lincomb on both sides. *)
    case empty thus ?case
      unfolding goal_def unfolding N.lincomb_def NM.lincomb_def by auto next
    (* Step case: a new vector u is inserted into U. *)
    case (insert u U)
      hence finU: "finite U"
        and U: "U  carrier_vec n"
        and u[simp]: "u : carrier_vec n"
        and uU: "u  U"
        and auU: "a : insert u U  UNIV"
        and aU: "a : U  UNIV"
        and au: "a u : UNIV"
        by auto
      (* Induction hypotheses for the right and the left variant. *)
      have IHr: "goal padr unpadr U" and IHl: "goal padl unpadl U"
        using insert(3) U aU by auto
      (* lincomb_insert2 in both dimensions, with the module_vec projections unfolded. *)
      note N_lci = N.lincomb_insert2[unfolded module_vec_simps]
      note NM_lci = NM.lincomb_insert2[unfolded module_vec_simps]
      have auu[simp]: "a u v u : carrier_vec n" using au u by simp
      have laU[simp]: "lincomb n a U : carrier_vec n"
        using N.lincomb_closed[unfolded module_vec_simps class_ring_simps, OF U aU].
      let ?m0 = "0v m :: 'a vec"
      have m0: "?m0 : carrier_vec m" by auto
      (* Split the contribution of u off the linear combination over insert u U. *)
      have ins: "lincomb n a (insert u U) = a u v u + lincomb n a U"
        using N_lci[OF finU U] auU uU u by auto
      show ?case
      proof
        (* Right-padding: distribute the padding over the sum (append_vec_add), apply the
           induction hypothesis to the second summand, and fold the result back into a
           linear combination over the padded set with lincomb_insert2. *)
        have "padr m (a u v u + lincomb n a U) =
          (a u v u + lincomb n a U) @v (?m0 + ?m0)" by auto
        also have "... = padr m (a u v u) + padr m (lincomb n a U)"
          using append_vec_add[symmetric, OF auu laU]
          using zero_carrier_vec[of m] by metis
        also have "padr m (lincomb n a U) = lincomb (n+m) (a o unpadr m) (padr m ` U)"
          using IHr unfolding goal_def.
        also have "padr m (a u v u) = a u v padr m u" by auto
        also have "... = (a o unpadr m) (padr m u) v padr m u" by auto
        also have "... + lincomb (n+m) (a o unpadr m) (padr m ` U) =
          lincomb (n+m) (a o unpadr m) (insert (padr m u) (padr m ` U))"
          apply(subst NM_lci[symmetric])
          using finU uU U append_vec_eq[OF u] by auto
        also have "insert (padr m u) (padr m ` U) = padr m ` insert u U"
          by auto
        finally show "goal padr unpadr (insert u U)" unfolding goal_def ins.
        (* Left-padding: same calculation with the zero block in front.  The extra simp
           rule swaps the dimension sum; cf. padl_image, which states the carrier of
           left-padded vectors as m + n. *)
        have [simp]: "n+m = m+n" by auto
        have "padl m (a u v u + lincomb n a U) =
          (?m0 + ?m0) @v (a u v u + lincomb n a U)" by auto
        also have "... = padl m (a u v u) + padl m (lincomb n a U)"
          using append_vec_add[symmetric, OF _ _ auu laU]
          using zero_carrier_vec[of m] by metis
        also have "padl m (lincomb n a U) = lincomb (n+m) (a o unpadl m) (padl m ` U)"
          using IHl unfolding goal_def.
        also have "padl m (a u v u) = a u v padl m u" by auto
        also have "... = (a o unpadl m) (padl m u) v padl m u" by auto
        also have "... + lincomb (n+m) (a o unpadl m) (padl m ` U) =
          lincomb (n+m) (a o unpadl m) (insert (padl m u) (padl m ` U))"
          apply(subst NM_lci[symmetric])
          using finU uU U append_vec_eq[OF m0] by auto
        also have "insert (padl m u) (padl m ` U) = padl m ` insert u U"
          by auto
        finally show "goal padl unpadl (insert u U)" unfolding goal_def ins.
      qed
  qed
  thus ?R ?L by auto
qed

(* Padding commutes with span: the image of span n U under pad m is the span, in
   dimension n+m, of the padded set pad m ` U.  Shown for padr and for padl.
   The generic block in the proof establishes the set equality for any pad/unpad
   pair that (i) commutes with lincomb on finite subsets of U (assumption main),
   (ii) satisfies unpad (pad v) = v, and (iii) maps U into carrier_vec (n+m).
   It is then instantiated with padr/unpadr and with padl/unpadl via lincomb_pad. *)
lemma span_pad:
  assumes U: "(U::'a vec set)  carrier_vec n"
  defines "goal pad m == pad m ` span n U = span (n+m) (pad m ` U)"
  shows "goal padr m" "goal padl m"
proof -
  interpret N: vectorspace class_ring "M n" using vec_vs.
  interpret NM: vectorspace class_ring "M (n+m)" using vec_vs.
  { fix pad :: "'a vec  'a vec" and unpad :: "'a vec  'a vec"
    assume main: "A a. A  U  finite A 
      pad (lincomb n a A) = lincomb (n+m) (a o unpad) (pad ` A)"
    assume [simp]: "v. unpad (pad v) = v"
    assume pU: "pad ` U  carrier_vec (n+m)"
    have "pad ` (span n U) = span (n+m) (pad ` U)"
    proof (intro Set.equalityI subsetI)
      (* First inclusion: a padded element of the span is a linear combination of padded vectors. *)
      fix x assume "x : pad ` (span n U)"
      then obtain v where "v : span n U" and xv: "x = pad v" by auto
      then obtain a A
        where AU: "A  U" and finA: "finite A" and a: "a : A  UNIV"
          and vaA: "v = lincomb n a A"
        unfolding N.span_def by auto
      hence A: "A  carrier_vec n" using U by auto
      show "x : span (n+m) (pad ` U)" unfolding NM.span_def
      proof (intro CollectI exI conjI)
        show "x = lincomb (n+m) (a o unpad) (pad ` A)"
          using xv vaA main[OF AU finA] by auto
        show "pad ` A  pad ` U" using AU by auto
      qed (insert finA, auto simp: class_ring_simps)
      next
      (* Second inclusion: a linear combination over a finite subset A' of the padded set
         is pulled back to a linear combination over a finite preimage A inside U. *)
      fix x assume "x : span (n+m) (pad ` U)"
      then obtain a' A'
        where A'U: "A'  pad ` U" and finA': "finite A'" and a': "a' : A'  UNIV"
          and xa'A': "x = lincomb (n+m) a' A'"
        unfolding NM.span_def by auto
      then obtain A where finA: "finite A" and AU: "A  U" and A'A: "A' = pad ` A"
        using finite_subset_image[OF finA' A'U] by auto
      hence A: "A  carrier_vec n" using U by auto
      have A': "A'  carrier_vec (n+m)" using A'U pU by auto
      (* a: coefficients transported to the unpadded vectors;
         a'': the same coefficients written in the shape (_ o unpad) that main expects. *)
      define a where "a = a' o pad"
      define a'' where "a'' = (a' o pad) o unpad"
      have a: "a : A  UNIV" by auto
      (* a' and a'' agree on A', because every element of A' is of the form pad u. *)
      have restr: "restrict a' A' = restrict a'' A'"
      proof(rule restrict_ext)
        fix u' assume "u' : A'"
        then obtain u where "u : A" and "u' = pad u" unfolding A'A by auto
        thus "a' u' = a'' u'" unfolding a''_def a_def by auto
      qed
      have "x = lincomb (n+m) a' A'" using xa'A' unfolding A'A.
      also have "... = lincomb (n+m) a'' A'"
        apply (subst NM.lincomb_restrict)
        using finA' A' restr by (auto simp: module_vec_simps class_ring_simps)
      also have "... = lincomb (n+m) a'' (pad ` A)" unfolding A'A..
      also have "... = pad (lincomb n a A)"
        unfolding a''_def using main[OF AU finA] unfolding a_def by auto
      finally show "x : pad ` (span n U)" unfolding N.span_def
      apply(rule image_eqI, intro CollectI exI conjI)
        using finA AU by (auto simp: class_ring_simps)
    qed
  }
  note main = this
  (* Instantiate the generic result; lincomb_pad discharges assumption main. *)
  have AUC: "A. A  U  A  carrier_vec n" using U by simp
  have [simp]: "n+m = m+n" by auto
  show "goal padr m" unfolding goal_def
    apply (subst main[OF _ _ padr_image[OF U]])
    using lincomb_pad[OF AUC] unpadr_padr by auto
  show "goal padl m" unfolding goal_def
    apply (subst main)
    using lincomb_pad[OF AUC] unpadl_padl padl_image[OF U] by auto
qed

(* If a lies in the kernel of A, then a padded on the right with nc2 zeros lies in the
   kernel of the block matrix built from A, B, a zero block and D, where (as the row
   facts used below show) A and B form the upper rows and the zero block and D the
   lower rows.  The proof checks the product row by row, distinguishing rows with
   index below nr1 from the remaining rows. *)
lemma kernel_padr:
  assumes aA: "a : mat_kernel (A :: 'a :: field mat)"
      and A: "A : carrier_mat nr1 nc1"
      and B: "B : carrier_mat nr1 nc2"
      and D: "D : carrier_mat nr2 nc2"
  shows "padr nc2 a : mat_kernel (four_block_mat A B (0m nr2 nc1) D)" (is "_ : mat_kernel ?ABCD")
  unfolding mat_kernel_def
proof (rule, intro conjI)
  have [simp]: "dim_row A = nr1" "dim_row D = nr2" "dim_row ?ABCD = nr1 + nr2" using A D by auto
  have a: "a : carrier_vec nc1" using mat_kernel_carrier[OF A] aA by auto
  show "?ABCD *v padr nc2 a = 0v (dim_row ?ABCD)" (is "?l = ?r")
  proof
    fix i assume i: "i < dim_vec ?r"
    hence "?l $ i = row ?ABCD i  padr nc2 a" by auto
    also have "... = 0"
    proof (cases "i < nr1")
      (* Upper rows: the row is row A i followed by row B i; the A-part vanishes
         because a is in the kernel of A, the B-part meets only padding zeros. *)
      case True
        hence rows: "row A i : carrier_vec nc1" "row B i : carrier_vec nc2"
          using A B by auto
        have "row ?ABCD i = row A i @v row B i"
          using row_four_block_mat(1)[OF A B _ D True] by auto
        also have "...  padr nc2 a = row A i  a + row B i  0v nc2"
          using scalar_prod_append[OF rows] a by auto
        also have "row A i  a = (A *v a) $ i" using True A by auto
        also have "... = 0" using mat_kernelD[OF A aA] True by auto
        also have "row B i  0v nc2 = 0" using True rows by auto
        finally show ?thesis by simp
      (* Lower rows: the row is a zero row followed by row D (i - nr1); the zero part
         annihilates a and the D-part meets only padding zeros. *)
      next case False
        let ?C = "0m nr2 nc1"
        let ?i = "i - nr1"
        have rows:
            "row ?C ?i : carrier_vec nc1" "row D ?i : carrier_vec nc2"
          using D i False A by auto
        have "row ?ABCD i = row ?C ?i @v row D ?i"
          using row_four_block_mat(2)[OF A B _ D False] i A D by auto
        also have "...  padr nc2 a = row ?C ?i  a + row D ?i  0v nc2"
          using scalar_prod_append[OF rows] a by auto
        also have "row ?C ?i  a = 0v nc1  a" using False A i by auto
        also have "... = 0" using a by auto
        also have "row D ?i  0v nc2 = 0" using False rows by auto
        finally show ?thesis by simp
    qed
    finally show "?l $ i = ?r $ i" using i by auto
  qed auto
  (* Second conjunct of mat_kernel_def: the padded vector has the right dimension. *)
  show "padr nc2 a : carrier_vec (dim_col ?ABCD)" using a A D by auto
qed

(* Counterpart of kernel_padr for left padding: if d lies in the kernel of D, then d
   padded on the left with nc1 zeros lies in the kernel of the block matrix built from
   A, a zero block, C and D, where (as the row facts used below show) A and the zero
   block form the upper rows and C and D the lower rows.  Again the product is checked
   row by row. *)
lemma kernel_padl:
  assumes dD: "d  mat_kernel (D :: 'a :: field mat)"
      and A: "A  carrier_mat nr1 nc1"
      and C: "C  carrier_mat nr2 nc1"
      and D: "D  carrier_mat nr2 nc2"
  shows "padl nc1 d  mat_kernel (four_block_mat A (0m nr1 nc2) C D)" (is "_  mat_kernel ?ABCD")
  unfolding mat_kernel_def
proof (rule, intro conjI)
  have [simp]: "dim_row A = nr1" "dim_row D = nr2" "dim_row ?ABCD = nr1 + nr2" using A D by auto
  have d: "d : carrier_vec nc2" using mat_kernel_carrier[OF D] dD by auto
  show "?ABCD *v padl nc1 d = 0v (dim_row ?ABCD)" (is "?l = ?r")
  proof
    fix i assume i: "i < dim_vec ?r"
    hence "?l $ i = row ?ABCD i  padl nc1 d" by auto
    also have "... = 0"
    proof (cases "i < nr1")
      (* Upper rows: row A i meets only padding zeros, and the zero row annihilates d. *)
      case True
        let ?B = "0m nr1 nc2"
        have rows: "row A i : carrier_vec nc1" "row ?B i : carrier_vec nc2"
          using A True by auto
        have "row ?ABCD i = row A i @v row ?B i"
          using row_four_block_mat(1)[OF A _ C D True] by auto
        also have "...  padl nc1 d = row A i  0v nc1 + row ?B i  d"
          using scalar_prod_append[OF rows] d by auto
        also have "row A i  0v nc1 = 0" using A True by auto
        also have "row ?B i  d = 0" using True d by auto
        finally show ?thesis by simp
      (* Lower rows: row C (i - nr1) meets only padding zeros, and the D-part vanishes
         because d is in the kernel of D. *)
      next case False
        let ?i = "i - nr1"
        have rows:
            "row C ?i : carrier_vec nc1" "row D ?i : carrier_vec nc2"
          using C D i False A by auto
        have "row ?ABCD i = row C ?i @v row D ?i"
          using row_four_block_mat(2)[OF A _ C D False] i A D by auto
        also have "...  padl nc1 d = row C ?i  0v nc1 + row D ?i  d"
          using scalar_prod_append[OF rows] d by auto
        also have "row C ?i  0v nc1 = 0" using False A C i by auto
        also have "row D ?i  d = (D *v d) $ ?i" using D d False i by auto
        also have "... = 0" using mat_kernelD[OF D dD] using False i by auto
        finally show ?thesis by simp
    qed
    finally show "?l $ i = ?r $ i" using i by auto
  qed auto
  (* Second conjunct of mat_kernel_def: the padded vector has the right dimension. *)
  show "padl nc1 d : carrier_vec (dim_col ?ABCD)" using d A D by auto
qed

(* A kernel vector k of the block matrix with blocks A and D and zero blocks elsewhere
   splits into kernel vectors of the blocks: its first n entries (vec_first k n) lie in
   the kernel of A and its last m entries (vec_last k m) in the kernel of D. *)
lemma mat_kernel_split:
  assumes A: "A  carrier_mat n n"
      and D: "D  carrier_mat m m"
      and kAD: "k  mat_kernel (four_block_mat A (0m n m) (0m m n) D)"
           (is "_  mat_kernel ?A00D")
  shows "vec_first k n  mat_kernel A" (is "?a  _")
    and "vec_last k m  mat_kernel D" (is "?d  _")
proof -
  (* Write the zero vector as a concatenation, and the product of the block matrix with
     k = ?a @v ?d as the concatenation of the two block products (mult_mat_vec_split). *)
  have "0v n @v 0v m = 0v (n+m)" by auto
  also
    have A00D: "?A00D : carrier_mat (n+m) (n+m)" using four_block_carrier_mat[OF A D].
    hence k: "k : carrier_vec (n+m)" using kAD mat_kernel_carrier by auto
    hence "?a @v ?d = k" by simp
    hence "0v (n+m) = ?A00D *v (?a @v ?d)" using mat_kernelD[OF A00D] kAD by auto
  also have "... = A *v ?a @v D *v ?d"
    using mult_mat_vec_split[OF A D] by auto
  finally have "0v n @v 0v m = A *v ?a @v D *v ?d".
  (* Compare the two halves of the concatenations (append_vec_eq). *)
  hence "0v n = A *v ?a  0v m = D *v ?d"
    apply(subst append_vec_eq[of _ n, symmetric]) using A D by auto
  thus "?a : mat_kernel A" "?d : mat_kernel D" unfolding mat_kernel_def using A D by auto
qed

(* For v of dimension n: a right-padded v (by m) and a left-padded u (by n) coincide
   only in the trivial case that v and u are the zero vectors 0v n and 0v m.
   Proven by comparing the two halves with append_vec_eq. *)
lemma padr_padl_eq:
  assumes v: "v : carrier_vec n"
  shows "padr m v = padl n u  v = 0v n  u = 0v m"
  apply (subst append_vec_eq) using v by auto


(* If A (vectors of dimension n) does not contain the zero vector, then the
   right-padded image of A and the left-padded image of B have no element in common. *)
lemma pad_disjoint:
  assumes A: "A  carrier_vec n" and A0: "0v n  A" and B: "B  carrier_vec m"
  shows "padr m ` A  padl n ` B = {}" (is "?A  ?B = _")
proof (intro equals0I)
  (* A common element would be both padr m a and padl n b, which forces a to be the
     zero vector and contradicts assumption A0. *)
  fix ab assume "ab : ?A  ?B"
  then obtain a b
    where "ab = padr m a" "ab = padl n b" and dim: "a : A" "b : B" by force
  hence "padr m a = padl n b" by auto
  hence "a = 0v n" using dim A B by auto
  thus "False" using dim A0 by auto
qed

(* Linear independence is preserved by combining padded sets: if A is linearly
   independent in dimension n and B in dimension m, then the union of padr m ` A and
   padl n ` B is linearly independent in dimension n+m.
   Proof idea: a vanishing linear combination over a finite subset U of the union is
   split into its padr-part and its padl-part; via lincomb_pad these are the paddings
   of linear combinations over preimages A' and B'; comparing the two halves shows
   both combinations vanish, so all coefficients are zero by independence of A and B. *)
lemma padr_padl_lindep:
  assumes A: "A  carrier_vec n" and liA: "~ lin_dep n A"
      and B: "B  carrier_vec m" and liB: "~ lin_dep m B"
  shows "~ lin_dep (n+m) (padr m ` A  padl n ` B)" (is "~ lin_dep _ (?A  ?B)")
proof -
  interpret N: vectorspace class_ring "M n" using vec_vs.
  interpret M: vectorspace class_ring "M m" using vec_vs.
  interpret NM: vectorspace class_ring "M (n+m)" using vec_vs.
  note [simp] = module_vec_simps class_ring_simps
  have AB: "?A  ?B  carrier_vec (n+m)"
    using padr_image[OF A] padl_image[OF B] by auto
  show ?thesis
    unfolding NM.lin_dep_def
    unfolding not_ex not_imp[symmetric] not_not
  proof(intro allI impI)
    (* Arbitrary finite U inside the union, coefficients f with vanishing linear
       combination, and an element u of U whose coefficient must be shown to be zero. *)
    fix U f u
    assume finU: "finite U"
       and UAB: "U  ?A  ?B"
       and f: "f : U  carrier class_ring"
       and 0: "lincomb (n+m) f U = 𝟬M (n+m)"
       and uU: "u : U"
    (* Split U along the two padded images and name preimages A' and B'. *)
    let ?UA = "U  ?A" and ?UB = "U  ?B"
    have "?UA  ?A" "?UB  ?B" by auto
    then obtain A' B'
      where A'A: "A'  A" and B'B: "B'  B"
        and UAA': "?UA = padr m ` A'" and UBB': "?UB = padl n ` B'"
      unfolding subset_image_iff by auto
    hence A': "A'  carrier_vec n" and B': "B'  carrier_vec m" using A B by auto
    (* The preimages are finite: their images sit inside the finite set U and the pad
       functions are injective on the respective carriers. *)
    have finA': "finite A'" and finB': "finite B'"
    proof -
      have "padr m ` A'  U" "padl n ` B'  U" using UAA' UBB' by auto
      hence pre: "finite (padr m ` A')" "finite (padl n ` B')"
        using finite_subset[OF _ finU] by auto
      show "finite A'"
        apply (rule finite_imageD) using subset_inj_on[OF padr_inj A'] pre by auto
      show "finite B'"
        apply (rule finite_imageD) using subset_inj_on[OF padl_inj B'] pre by auto
    qed
    (* A linearly independent set does not contain the zero vector, so pad_disjoint
       applies and the two parts of U are disjoint. *)
    have "0v n  A" using N.zero_nin_lin_indpt[OF _ liA] A class_semiring.one_zeroI by auto
    hence "?A  ?B = {}" using pad_disjoint A B by auto
    hence disj: "?UA  ?UB = {}" by auto
    have split: "U = padr m ` A'  padl n ` B'"
      unfolding UAA'[symmetric] UBB'[symmetric] using UAB by auto
    show "f u = 𝟬(class_ring::'a ring)"
    proof -
      (* Coefficient functions transported to the unpadded vectors. *)
      let ?a = "f  padr m"
      let ?b = "f  padl n"
      have lcA': "lincomb n ?a A' : carrier_vec n" using N.lincomb_closed A' by auto
      have lcB': "lincomb m ?b B' : carrier_vec m" using M.lincomb_closed B' by auto
  
      (* Long calculation: rewrite the vanishing combination over U step by step into
         the concatenation of lincomb n ?a A' and lincomb m ?b B'. *)
      have "0v n @v 0v m = 0v (n+m)" by auto
      also have "... = lincomb (n+m) f U" using 0 by auto
      also have "U = ?UA  ?UB" using UAB by auto
      also have "lincomb (n+m) f ... = lincomb (n+m) f ?UA + lincomb (n+m) f ?UB"
        apply(subst NM.lincomb_union) using A B finU disj by auto
      (* On ?UA the coefficients f agree with ?a composed with unpadr m; the restrict
         steps make this replacement under lincomb. *)
      also have "lincomb (n+m) f ?UA = lincomb (n+m) (restrict f ?UA) ?UA"
        apply (subst NM.lincomb_restrict) using A finU by auto
      also have "restrict f ?UA = restrict (?a  unpadr m) ?UA"
        apply(rule restrict_ext) by auto
      also have "lincomb (n+m) ... ?UA = lincomb (n+m) (?a  unpadr m) ?UA"
        apply(subst NM.lincomb_restrict) using A finU by auto
      also have "?UA = padr m ` A'" using UAA'.
      also have "lincomb (n+m) (?a  unpadr m) ... =
        padr m (lincomb n ?a A')"
        using lincomb_pad(1)[OF A' finA',symmetric].
      (* Same replacement on ?UB with ?b composed with unpadl n. *)
      also have "lincomb (n+m) f ?UB = lincomb (n+m) (restrict f ?UB) ?UB"
        apply (subst NM.lincomb_restrict) using B finU by auto
      also have "restrict f ?UB = restrict (?b  unpadl n) ?UB"
        apply(rule restrict_ext) by auto
      also have "lincomb (n+m) ... ?UB = lincomb (n+m) (?b  unpadl n) ?UB"
        apply(subst NM.lincomb_restrict) using B finU by auto
      (* Swap the dimension sum so that lincomb_pad(2), stated for m+n, matches. *)
      also have "n+m = m+n" by auto
      also have "?UB = padl n ` B'" using UBB'.
      also have "lincomb (m+n) (?b  unpadl n) ... =
        padl n (lincomb m ?b B')"
        using lincomb_pad(2)[OF B' finB',symmetric].
      also have "padr m (lincomb n ?a A') + ... =
          (lincomb n ?a A' + 0v n) @v (0v m + lincomb m ?b B')"
        apply (rule append_vec_add) using lcA' lcB' by auto
      also have "... = lincomb n ?a A' @v lincomb m ?b B'" using lcA' lcB' by auto
      finally have "0v n @v 0v m = lincomb n ?a A' @v lincomb m ?b B'".
      (* Compare halves: both linear combinations are zero. *)
      hence "0v n = lincomb n ?a A'  0v m = lincomb m ?b B'"
        apply(subst append_vec_eq[symmetric]) using lcA' lcB' by auto
      (* Independence of A and B makes all transported coefficients zero, hence f is
         zero on all of U, in particular at u. *)
      from conjunct1[OF this] conjunct2[OF this]
      have "?a : A'  {0}" "?b : B'  {0}"
        using N.not_lindepD[OF liA finA' A'A]
        using M.not_lindepD[OF liB finB' B'B] by auto
      hence "f : padr m ` A'  {0}" "f : padl n ` B'  {0}" by auto
      hence "f : padr m ` A'  padl n ` B'  {0}" by auto
      hence "f : U  {0}" using split by auto
      hence "f u = 0" using uU by auto
      thus ?thesis by simp
    qed
  qed
qed

end

(* Kernel dimension of a block matrix with square blocks B and D and zero blocks
   elsewhere: it is the sum of the kernel dimensions of B and D.
   Proof idea: choose finite bases baseB and baseD of the two kernels, pad them
   (padr m for baseB, padl n for baseD), show that the union ?BD of the padded sets is
   a basis of the kernel of A, and count its elements. *)
lemma kernel_four_block_0_mat:
  assumes Adef: "(A :: 'a::field mat) = four_block_mat B (0m n m) (0m m n) D"
  and B: "B  carrier_mat n n"
  and D: "D  carrier_mat m m"
  shows "kernel.dim (n + m) A = kernel.dim n B + kernel.dim m D"
proof -
  have [simp]: "n + m = m + n" by auto
  have A: "A  carrier_mat (n+m) (n+m)"
    using Adef four_block_carrier_mat[OF B D] by auto
  (* Bring the padding machinery of locale vardim and the three kernel locales into scope. *)
  interpret vardim "TYPE('a)".
  interpret MN: vectorspace class_ring "M (n+m)" using vec_vs.
  interpret KA: kernel "n+m" "n+m" A by (unfold_locales, rule A)
  interpret KB: kernel n n B by (unfold_locales, rule B)
  interpret KD: kernel m m D by (unfold_locales, rule D)

  note [simp] = module_vec_simps

  (* Basis of the kernel of B; its right-padded image lies in the kernel of A (kernel_padr). *)
  from kernel_basis_exists[OF B]
    obtain baseB where fin_bB: "finite baseB" and bB: "KB.basis baseB" by blast
  hence bBkB: "baseB  mat_kernel B" unfolding KB.Ker.basis_def by auto
  hence bBc: "baseB  carrier_vec n" using mat_kernel_carrier[OF B] by auto
  have bB0: "0v n  baseB"
    using bB unfolding KB.Ker.basis_def
    using KB.Ker.vs_zero_lin_dep[OF bBkB] by auto
  have bBkA: "padr m ` baseB  mat_kernel A"
  proof
    fix a assume "a : padr m ` baseB"
    then obtain b where ab: "a = padr m b" and "b : baseB" by auto
    hence "b : mat_kernel B" using bB unfolding KB.Ker.basis_def by auto
    hence "padr m b : mat_kernel A"
      unfolding Adef using kernel_padr[OF _ B _ D] by auto
    thus "a : mat_kernel A" using ab by auto
  qed
  (* Basis of the kernel of D; its left-padded image lies in the kernel of A (kernel_padl). *)
  from kernel_basis_exists[OF D]
    obtain baseD where fin_bD: "finite baseD" and bD: "KD.basis baseD" by blast
  hence bDkD: "baseD  mat_kernel D" unfolding KD.Ker.basis_def by auto
  hence bDc: "baseD  carrier_vec m" using mat_kernel_carrier[OF D] by auto
  have bDkA: "padl n ` baseD  mat_kernel A"
  proof
    fix a assume "a : padl n ` baseD"
    then obtain d where ad: "a = padl n d" and "d : baseD" by auto
    hence "d : mat_kernel D" using bD unfolding KD.Ker.basis_def by auto
    hence "padl n d : mat_kernel A"
      unfolding Adef using kernel_padl[OF _ B _ D] by auto
    thus "a : mat_kernel A" using ad by auto
  qed
  (* Candidate basis of the kernel of A. *)
  let ?BD = "(padr m ` baseB  padl n ` baseD)"
  have finBD: "finite ?BD" using fin_bB fin_bD by auto
  have "KA.basis  ?BD"
    unfolding KA.Ker.basis_def
  proof (intro conjI Set.equalityI)
    show BDk: "?BD  mat_kernel A" using bBkA bDkA by auto
    also have "mat_kernel A  carrier_vec (m+n)" using mat_kernel_carrier A by auto
    finally have BD: "?BD  carrier (M (n + m))" by auto
    (* Generating: every kernel vector a of A is the sum of a right-padded kernel vector
       of B and a left-padded kernel vector of D, each of which lies in the span of ?BD. *)
    show "mat_kernel A  KA.Ker.span ?BD"
      unfolding KA.span_same[OF BDk]
    proof
      have BD: "?BD  carrier_vec (n+m)" (is "_  ?R")
      proof(rule)
        fix v assume "v : ?BD"
        moreover
        { assume "v : padr m ` baseB"
          then obtain b where "b : baseB" and vb: "v = padr m b" by auto
          hence "b : carrier_vec n" using bBc by auto
          hence "v : ?R" unfolding vb apply(subst append_carrier_vec) by auto
        }
        moreover
        { assume "v : padl n ` baseD"
          then obtain d where "d : baseD" and vd: "v = padl n d" by auto
          hence "d : carrier_vec m" using bDc by auto
          hence "v : ?R" unfolding vd apply(subst append_carrier_vec) by auto
        }
        ultimately show "v: ?R" by auto
      qed
      fix a assume a: "a : mat_kernel A"
      hence "a : carrier_vec (n+m)" using a mat_kernel_carrier[OF A] by auto
      (* Decompose a into its first n and last m entries. *)
      hence "a = vec_first a n @v vec_last a m" (is "_ = ?b @v ?d") by simp
      also have "... = padr m ?b + padl n ?d" by auto
      finally have 1: "a = padr m ?b + padl n ?d".
  
      (* Both parts are kernel vectors of the blocks (mat_kernel_split), hence lie in the
         spans of baseB and baseD; span_pad transports this to the padded sets. *)
      have subkernel: "?b : mat_kernel B" "?d : mat_kernel D"
        using mat_kernel_split[OF B D] a Adef by auto
      hence "?b : span n baseB"
        using bB unfolding KB.Ker.basis_def using KB.span_same by auto
      hence "padr m ?b : padr m ` span n baseB" by auto
      also have "padr m ` span n baseB = span (n+m) (padr m ` baseB)"
        using span_pad[OF bBc] by auto
      also have "...  span (n+m) ?BD" using MN.span_is_monotone by auto
      finally have 2: "padr m ?b : span (n+m) ?BD".
      have "?d : span m baseD"
        using subkernel bD unfolding KD.Ker.basis_def using KD.span_same by auto
      hence "padl n ?d : padl n ` span m baseD" by auto
      also have "padl n ` span m baseD = span (n+m) (padl n ` baseD)"
        using span_pad[OF bDc] by auto
      also have "...  span (n+m) ?BD" using MN.span_is_monotone by auto
      finally have 3: "padl n ?d : span (n+m) ?BD".
  
      have "padr m ?b + padl n ?d : span (n+m) ?BD"
        using MN.span_add1[OF _ 2 3] BD by auto
      thus "a  span (n+m) ?BD" using 1 by auto
    qed
    show "KA.Ker.span ?BD  mat_kernel A" using KA.Ker.span_closed[OF BDk] by auto
    (* Linear independence of ?BD follows from that of the two bases via padr_padl_lindep. *)
    have li: "~ lin_dep n baseB" "~ lin_dep m baseD"
      using bB[unfolded KB.Ker.basis_def]
      unfolding KB.lindep_same[OF bBkB]
      using bD[unfolded KD.Ker.basis_def]
      unfolding KD.lindep_same[OF bDkD] by auto
    show "~ KA.Ker.lin_dep ?BD"
      unfolding KA.lindep_same[OF BDk]
      apply(rule padr_padl_lindep) using bBc bDc li by auto
  qed
  (* Count: the two padded images are disjoint (pad_disjoint) and padding is injective,
     so card ?BD is the sum of the cardinalities of the two bases. *)
  hence "KA.dim = card ?BD" using KA.Ker.dim_basis[OF finBD] by auto
  also have "card ?BD = card (padr m ` baseB) + card (padl n ` baseD)"
    apply(rule card_Un_disjoint)
    using pad_disjoint[OF bBc bB0 bDc] fin_bB fin_bD by auto
  also have "... = card baseB + card baseD"
    using card_image[OF subset_inj_on[OF padr_inj]]
    using card_image[OF subset_inj_on[OF padl_inj]] bBc bDc by auto
  also have "card baseB = KB.dim" using KB.Ker.dim_basis[OF fin_bB] bB by auto
  also have "card baseD = KD.dim" using KD.Ker.dim_basis[OF fin_bD] bD by auto
  finally show ?thesis.

qed

(* Matrices that are similar via witnesses P and Q have the same kernel dimension.
   With A = P * B * Q, the right factor Q is removed by mat_kernel_dim_mult_eq_right
   and the left factor P by mat_kernel_mult_eq. *)
lemma similar_mat_wit_kernel_dim: assumes A: "A  carrier_mat n n"
  and wit: "similar_mat_wit A B P Q"
  shows "kernel.dim n A = kernel.dim n B"
proof -
  (* Unpack the similarity witness into carrier facts and the defining equations. *)
  from similar_mat_witD2[OF A wit]
  have QP: "Q * P = 1m n" and AB: "A = P * B * Q" and 
    A: "A  carrier_mat n n" and B: "B  carrier_mat n n" and P: "P  carrier_mat n n" and Q: "Q  carrier_mat n n" by auto
  from P B have PB: "P * B  carrier_mat n n" by auto
  show ?thesis unfolding AB mat_kernel_dim_mult_eq_right[OF PB Q P QP] mat_kernel_mult_eq[OF B P Q QP]
    by simp
qed


end

Theory Jordan_Normal_Form_Uniqueness

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Jordan Normal Form -- Uniqueness›

text ‹We prove that the Jordan normal form of a matrix
  is unique up to permutations of the blocks. We do this
  via generalized eigenspaces and an algorithm which
  computes, for each potential Jordan block (ev,n), how often
  it occurs in any Jordan normal form.›

theory Jordan_Normal_Form_Uniqueness
imports
  Jordan_Normal_Form
  Matrix_Kernel
begin

(* Witnesses P, Q of similarity between A and B also witness similarity between the
   characteristic matrices char_matrix A ev and char_matrix B ev.
   The calculation rewrites char_matrix A ev as P * B * Q plus a scalar multiple of
   P * Q and then factors P out on the left and Q out on the right. *)
lemma similar_mat_wit_char_matrix: assumes wit: "similar_mat_wit A B P Q"
  shows "similar_mat_wit (char_matrix A ev) (char_matrix B ev) P Q"
proof -
  define n where "n = dim_row A"
  let ?C = "carrier_mat n n"
  (* Unpack the witness: all four matrices are n x n, P and Q are mutually inverse. *)
  from similar_mat_witD[OF refl wit, folded n_def] have
    A: "A  ?C" and B: "B  ?C" and P: "P  ?C" and Q: "Q  ?C"
    and PQ: "P * Q = 1m n" and QP: "Q * P = 1m n"
    and AB: "A = P * B * Q"
    by auto
  have "char_matrix A ev = (P * B * Q + (-ev) m (P * Q))"
    unfolding char_matrix_def n_def[symmetric] unfolding AB PQ 
    by (intro eq_matI, insert P B Q, auto)
  (* Move the scalar multiple of the identity between P and Q. *)
  also have "(-ev) m (P * Q) = P * ((-ev) m 1m n) * Q" using P Q
    by (metis mult_smult_assoc_mat mult_smult_distrib one_carrier_mat right_mult_one_mat)
  (* Factor out Q on the right, then P on the left. *)
  also have "P * B * Q +  = (P * B + P * ((-ev) m 1m n)) * Q" using P B
    by (intro add_mult_distrib_mat[symmetric, OF _ _ Q, of _ n], auto)
  also have "P * B + P * ((-ev) m 1m n) = P * (B  + (-ev) m 1m n)"
    by (intro mult_add_distrib_mat[symmetric, OF P B], auto)
  also have "(B  + (-ev) m 1m n) = char_matrix B ev" unfolding char_matrix_def
    by (intro eq_matI, insert B, auto)
  finally have AB: "char_matrix A ev = P * char_matrix B ev * Q" .
  show "similar_mat_wit (char_matrix A ev) (char_matrix B ev) P Q"
    by (intro similar_mat_witI[OF PQ QP AB _ _ P Q], insert A B, auto)
qed

context fixes ty :: "'a :: field itself"
begin

(* Every power of a Jordan block of size n for a nonzero value a has kernel dimension 0.
   Proven with kernel_upper_triangular after rewriting the power via jordan_block_pow;
   the remaining side conditions are discharged by auto with diag_mat_def and the
   assumption on a (presumably: the diagonal entries are nonzero). *)
lemma dim_kernel_non_zero_jordan_block_pow: assumes a: "a  0"
  shows "kernel.dim n (jordan_block n (a :: 'a) ^m k) = 0"
  by (rule kernel_upper_triangular[OF pow_carrier_mat[OF jordan_block_carrier]],
  unfold jordan_block_pow, insert a, auto simp: diag_mat_def)

(* The k-th power of the Jordan block of size n for value 0 has kernel dimension min k n.
   Proof idea: the power is in row echelon form, witnessed by the pivot function
   i |-> min (k + i) n; the kernel dimension is then n minus the number of pivot
   positions, and the pivot rows are exactly 0 ..< n - min k n. *)
lemma dim_kernel_zero_jordan_block_pow: 
  "kernel.dim n ((jordan_block n (0 :: 'a)) ^m k) = min k n" (is "kernel.dim _ ?A = ?c")
proof - 
  have A: "?A  carrier_mat n n" by auto
  hence dim: "dim_row ?A = n" by simp
  (* Pivot function for the power, using its explicit form jordan_block_zero_pow. *)
  let ?f = "λ i. min (k + i) n"
  have piv: "pivot_fun ?A ?f n" unfolding jordan_block_zero_pow
    by (intro pivot_funI, auto)
  hence row: "row_echelon_form ?A" unfolding row_echelon_form_def by auto
  (* Kernel dimension of a matrix in row echelon form in terms of its pivot positions. *)
  from find_base_vectors(5-6)[OF row A]
  have "kernel.dim n ?A = n - length (map fst (pivot_positions ?A))" by auto
  (* The pivot rows are distinct, so the list length is the cardinality of the row set. *)
  also have "length (map fst (pivot_positions ?A)) = card (fst ` set (pivot_positions ?A))"
    by (subst distinct_card[OF pivot_positions(2)[OF A piv], symmetric], simp)
  also have "fst ` set (pivot_positions ?A) = { 0 ..< (n - ?c)}" unfolding pivot_positions(1)[OF A piv]
    by force
  also have "card  = n - ?c" by simp
  finally show ?thesis by simp
qed
  
(* dim_gen_eigenspace A ev k: the kernel dimension of the k-th power of the
   characteristic matrix char_matrix A ev. *)
definition dim_gen_eigenspace :: "'a mat  'a  nat  nat" where
  "dim_gen_eigenspace A ev k = kernel_dim ((char_matrix A ev) ^m k)"

(* For a Jordan matrix described by the list n_as of (size, value) pairs,
   dim_gen_eigenspace at ev and k is the sum of min k n over all blocks (n, e) of the
   list with e = ev.
   Proof by induction on n_as: the Jordan matrix, its characteristic matrix and the
   k-th power all decompose into a first block and the rest with zero off-diagonal
   blocks, so kernel_four_block_0_mat adds up the kernel dimensions; the contribution
   of a single block is given by the two Jordan-block-power lemmas above. *)
lemma dim_gen_eigenspace_jordan_matrix: 
  "dim_gen_eigenspace (jordan_matrix n_as) ev k
    = ( n  map fst [(n, e)n_as . e = ev]. min k n)"
proof -
  (* Abbreviations, all parameterized by the block list. *)
  let ?JM = "λ n_as. jordan_matrix n_as"
  let ?CM = "λ n_as. char_matrix (?JM n_as) ev"
  let ?A = "λ n_as. (?CM n_as) ^m k"
  let ?n = "λ n_as. sum_list (map fst n_as)"
  let ?C = "λ n_as. carrier_mat (?n n_as) (?n n_as)"
  let ?sum = "λ n_as.  n  map fst [(n, e)n_as . e = ev]. min k n"
  let ?dim = "λ n_as. sum_list (map fst n_as)"
  let ?kdim = "λ n_as. kernel.dim (?dim n_as) (?A n_as)"
  have JM: " n_as. ?JM n_as  ?C n_as" by auto
  have CM: " n_as. ?CM n_as  ?C n_as" by auto
  have A: " n_as. ?A n_as  ?C n_as" by auto  
  have dimc: "dim_col (?JM n_as) = ?dim n_as" by simp
  interpret K: kernel "?dim n_as" "?dim n_as" "?A n_as"
    by (unfold_locales, rule A)
  show ?thesis unfolding dim_gen_eigenspace_def K.kernel_dim 
  proof (induct n_as)
    (* Empty list: the matrix is the 0 x 0 identity. *)
    case Nil
    have "?JM Nil = 1m 0" unfolding jordan_matrix_def
      by (intro eq_matI, auto)
    hence id: "?A Nil = 1m 0" unfolding char_matrix_def by auto
    show ?case unfolding id using kernel_one_mat[of 0] by auto
  next
    (* First block ne = (n, e) followed by the remaining blocks n_as'. *)
    case (Cons ne n_as')
    let ?n_as = "Cons ne n_as'"
    let ?d = "?dim ?n_as"
    let ?d' = "?dim n_as'"
    obtain n e where ne: "ne = (n,e)" by force
    have dim: "?d = n + ?d'" unfolding ne by simp
    let ?jb = "jordan_block n e"
    let ?cm = "char_matrix ?jb ev"
    let ?a = "?cm ^m k"
    have a: "?a  carrier_mat n n" by simp
    from JM[of n_as'] have dim_rec: "dim_row (?JM n_as') = ?d'" "dim_col (?JM n_as') = ?d'" by auto
    (* Block decompositions of the Jordan matrix, its characteristic matrix, and the power. *)
    hence JM_id: "?JM ?n_as = four_block_mat ?jb (0m n ?d') (0m ?d' n) (?JM n_as')"
      unfolding ne jordan_matrix_def using JM[of n_as']
      by (simp add: Let_def)
    have CM_id: "?CM ?n_as = four_block_mat ?cm (0m n ?d') (0m ?d' n) (?CM n_as')"
      unfolding JM_id
      unfolding char_matrix_def
      by (intro eq_matI, auto)
    have A_id: "?A ?n_as = four_block_mat ?a (0m n ?d') (0m ?d' n) (?A n_as')"
      unfolding CM_id by (rule pow_four_block_mat[OF _ CM], auto)
    (* Kernel dimensions of the two diagonal blocks add up. *)
    have kdim: "?kdim ?n_as = kernel.dim n ?a + ?kdim n_as'"
      unfolding dim A_id
      by (rule kernel_four_block_0_mat[OF refl a A])
    also have "?kdim n_as' = ?sum n_as'" by (rule Cons)
    (* Contribution of the first block: min k n if its value is ev, otherwise 0. *)
    also have "kernel.dim n ?a = (if e = ev then min k n else 0)"
      using dim_kernel_zero_jordan_block_pow[of n k]
        dim_kernel_non_zero_jordan_block_pow[of "e - ev" n k]
      unfolding char_matrix_jordan_block
      by (cases "e = ev", auto)
    also have " + ?sum n_as' = ?sum ?n_as" unfolding ne by auto
    finally show ?case .
  qed
qed

  
(* Similar matrices have the same dim_gen_eigenspace function.
   The similarity witness is lifted to the characteristic matrices
   (similar_mat_wit_char_matrix) and to their k-th powers (similar_mat_wit_pow);
   similar_mat_wit_kernel_dim then gives equal kernel dimensions. *)
lemma dim_gen_eigenspace_similar: assumes sim: "similar_mat A B"
  shows "dim_gen_eigenspace A = dim_gen_eigenspace B"
proof (intro ext)
  fix ev k
  define n where "n = dim_row A"
  from sim[unfolded similar_mat_def] obtain P Q where
    wit: "similar_mat_wit A B P Q" by auto
  let ?C = "carrier_mat n n"
  from similar_mat_witD[OF refl wit, folded n_def]
    have A: "A  ?C" and B: "B  ?C" and P: "P  ?C" and Q: "Q  ?C" 
    and PQ: "P * Q = 1m n" and QP: "Q * P = 1m n"
    by auto
  (* Witness for the k-th powers of the characteristic matrices (shadows the earlier fact wit). *)
  from similar_mat_wit_pow[OF similar_mat_wit_char_matrix[OF wit, of ev], of k]
  have wit: "similar_mat_wit (char_matrix A ev ^m k) (char_matrix B ev ^m k) P Q" .
  from A B have cA: "char_matrix A ev ^m k  carrier_mat n n" 
    and cB: "char_matrix B ev ^m k  carrier_mat n n" by auto
  hence dim: "dim_col (char_matrix A ev ^m k) = n" "dim_col (char_matrix B ev ^m k) = n" by auto
  have "dim_gen_eigenspace A ev k = kernel_dim (char_matrix A ev ^m k)"
    unfolding dim_gen_eigenspace_def using A by simp
  (* Unfold kernel_dim and rewrite both column counts to n so the kernel.dim lemma applies. *)
  also have " = kernel_dim (char_matrix B ev ^m k)" unfolding kernel_dim_def dim
    by (rule similar_mat_wit_kernel_dim[OF cA wit])
  also have " = dim_gen_eigenspace B ev k" 
    unfolding dim_gen_eigenspace_def using B by simp
  finally show "dim_gen_eigenspace A ev k = dim_gen_eigenspace B ev k" .
qed
  
(* Closed form of dim_gen_eigenspace A ev k for a matrix A with Jordan
   normal form n_as: it is expressed through min k n for the sizes n of
   those blocks (n, e) in n_as whose eigenvalue e equals ev.
   The result for the Jordan matrix itself
   (dim_gen_eigenspace_jordan_matrix) is transferred to A with
   dim_gen_eigenspace_similar, using the similarity stored in jordan_nf. *)
lemma dim_gen_eigenspace: assumes "jordan_nf A n_as"
  shows "dim_gen_eigenspace A ev k
    = ( n  map fst [(n, e)n_as . e = ev]. min k n)"
proof -
  from assms[unfolded jordan_nf_def]
  have sim: "similar_mat A (jordan_matrix n_as)" by auto
  from dim_gen_eigenspace_jordan_matrix[of n_as, folded dim_gen_eigenspace_similar[OF this]]
  show ?thesis .
qed

(* Combines dim_gen_eigenspace A ev at the three arguments k - 1, k and
   Suc k as  2 * d k - d (k - 1) - d (Suc k).
   The declared result type is nat, so the subtractions are natural-number
   subtractions.  The lemma of the same name states what this value counts
   when A has a Jordan normal form and k is non-zero. *)
definition compute_nr_of_jordan_blocks :: "'a mat  'a  nat  nat" where
  "compute_nr_of_jordan_blocks A ev k = 2 * dim_gen_eigenspace A ev k -
     dim_gen_eigenspace A ev (k - 1) - dim_gen_eigenspace A ev (Suc k)"

text ‹This lemma finally shows uniqueness of JNFs. Take an arbitrary
  JNF of a matrix $A$, (encoded by the list of Jordan-blocks @{term n_as}),
  then the number of occurrences of each Jordan-Block in @{term n_as} 
  is uniquely determined, namely by @{const compute_nr_of_jordan_blocks}. 
  The condition @{term "k ≠ (0 :: nat)"}
  is to ensure that we do not count blocks of dimension 0.›

(* For a Jordan normal form n_as of A and a non-zero block size k,
   compute_nr_of_jordan_blocks A ev k equals the number of entries of n_as
   that are equal to (k, ev).
   Proof outline:
   1. write k as Suc k1 and unfold the definition;
   2. replace every dim_gen_eigenspace value by its closed form
      (lemma dim_gen_eigenspace) and prove the resulting arithmetic identity
      by induction over the list of block sizes for eigenvalue ev;
   3. relate counting k in that size list to counting (k, ev) in n_as. *)
lemma compute_nr_of_jordan_blocks: assumes jnf: "jordan_nf A n_as"
  and no_0: "k  0"
  shows "compute_nr_of_jordan_blocks A ev k = length (filter ((=) (k,ev)) n_as)"
proof -
  from no_0 obtain k1 where k: "k = Suc k1" by (cases k, auto)
  let ?k = "Suc k1" let ?k2 = "Suc ?k"
  let ?dim = "dim_gen_eigenspace A ev"
  let ?sizes = "map fst [(n, e)n_as . e = ev]"
  define sizes where "sizes = ?sizes"
  let ?two = "length (filter ((=) (k, ev)) n_as)"
  have "compute_nr_of_jordan_blocks A ev k = 
    ?dim ?k + ?dim ?k - ?dim k1 - ?dim ?k2" unfolding compute_nr_of_jordan_blocks_def k by simp
  also have " = length (filter ((=) k) ?sizes)"
    unfolding dim_gen_eigenspace[OF jnf] k sizes_def[symmetric]
  proof (rule sym, induct sizes)
    case (Cons s sizes)
    show ?case
    (* only a head element equal to Suc k1 changes the count; every other
       head is discharged by auto from the induction hypothesis *)
    proof (cases "s = ?k")
      case True
      let ?sum = "λ k sizes. sum_list (map (min k) sizes)"
      let ?len = "λ sizes. length (filter ((=) ?k) sizes)"
      have len: "?len (s # sizes) = Suc (?len sizes)" unfolding True by simp
      have IH: "?len sizes = ?sum ?k sizes + ?sum ?k sizes -
        ?sum k1 sizes - ?sum ?k2 sizes" by (rule Cons)
      have "?sum ?k (s # sizes) + ?sum ?k (s # sizes) -
        ?sum k1 (s # sizes) - ?sum ?k2 (s # sizes)
        = Suc (?sum ?k sizes + ?sum ?k sizes) - 
         (?sum k1 sizes + ?sum ?k2 sizes)"
        using True by simp
      (* Suc may be pulled out of the nat subtraction once the subtrahend is
         known not to exceed the minuend (Suc_diff_le); that side condition is
         proved by a nested induction over sizes *)
      also have " = Suc (?sum ?k sizes + ?sum ?k sizes - (?sum k1 sizes + ?sum ?k2 sizes))"
        by (rule Suc_diff_le, induct sizes, auto)
      also have " = ?len (s # sizes)" unfolding len IH by simp
      finally show ?thesis by simp
    qed (insert Cons, auto)
  qed simp
  also have " = length (filter ((=) (k, ev)) n_as)" by (induct n_as, force+)
  finally show ?thesis .
qed

(* Builds a list of candidate Jordan blocks (k, ev) of A for eigenvalue ev.
   - k (outer) is Polynomial.order ev (char_poly A);
   - as tabulates dim_gen_eigenspace A ev at the indices 0 ..< Suc (Suc k);
   - cards pairs every size in [1 ..< Suc k] with the same expression as in
     compute_nr_of_jordan_blocks, read from the table as;
   - the result keeps the sizes selected by the filter on the count c and
     tags them with ev.
   NOTE(review): the filter is presumably "c is non-zero"; confirm against
   the original sources, the comparison sign is not legible in this copy. *)
definition compute_set_of_jordan_blocks :: "'a mat  'a  (nat × 'a)list" where
  "compute_set_of_jordan_blocks A ev  let 
     k = Polynomial.order ev (char_poly A);
     as = map (dim_gen_eigenspace A ev) [0 ..< Suc (Suc k)];
     cards = map (λ k. (k, 2 * as ! k - as ! (k - 1) - as ! Suc k)) [1 ..< Suc k]
     in map (λ (k,c). (k,ev)) (filter (λ (k,c). c  0) cards)"

(* Correctness of compute_set_of_jordan_blocks: for a Jordan normal form n_as
   of A, the set of the computed list is characterised by the entries of
   n_as that have eigenvalue ev (the set ?N').
   Proof outline:
   - ?N additionally removes pairs with first component 0; jordan_nf_def
     is used to show that this does not change the set (fact N);
   - the table lookups as ! i are replaced by dim_gen_eigenspace A ev i, so
     cards is rewritten in terms of compute_nr_of_jordan_blocks;
   - membership of (k, ev) is then compared by a case analysis on k = 0 and
     on k < Suc kk, using the counting lemma compute_nr_of_jordan_blocks for
     small k and jordan_nf_block_size_order_bound for large k. *)
lemma compute_set_of_jordan_blocks: assumes jnf: "jordan_nf A n_as"
  shows "set (compute_set_of_jordan_blocks A ev) = set n_as  UNIV × {ev}" (is "?C = ?N'")
proof -
  let ?N = "set n_as  UNIV × {ev} - {0} × UNIV" 
  have N: "?N' = ?N" using jnf[unfolded jordan_nf_def] by force
  note cjb = compute_nr_of_jordan_blocks[OF jnf]
  note d = compute_set_of_jordan_blocks_def Let_def
  (* local names mirroring the let-bindings of the definition *)
  define kk where "kk = Polynomial.order ev (char_poly A)"
  define as where "as = map (dim_gen_eigenspace A ev) [0 ..< Suc (Suc kk)]"
  define cards where "cards = map (λ k. (k, 2 * as ! k - as ! (k - 1) - as ! Suc k)) [1 ..< Suc kk]"
  have C: "?C = set (map (λ (k,c). (k,ev)) (filter (λ (k,c). c  0) cards))"
    unfolding d as_def kk_def cards_def by (rule refl)
  {
    fix i
    assume "i < Suc (Suc kk)"
    hence "as ! i = dim_gen_eigenspace A ev i"
      unfolding as_def by (auto simp del: upt_Suc)
  } note as = this
  (* TODO: perhaps use special code equation, and use inefficient thing in definition *)
  have cards: "cards = map (λ k. (k, compute_nr_of_jordan_blocks A ev k)) [1 ..< Suc kk]"
    unfolding cards_def
    by (rule map_cong[OF refl], insert as, unfold compute_nr_of_jordan_blocks_def, auto)
  (* from here on C denotes the set-comprehension form of the computed set *)
  have C: "?C = { (k,ev) | k. compute_nr_of_jordan_blocks A ev k  0  k  0  k < Suc kk }"
    unfolding C cards by force
  {
    fix k
    have "(k,ev)  ?C  (k,ev)  ?N"
    proof (cases "k = 0")
      case True
      thus ?thesis unfolding C by auto
    next
      case False
      show ?thesis
      proof (cases "k < Suc kk")
        case True
        have "length (filter ((=) (k, ev)) n_as)  0 
          set (filter ((=) (k, ev)) n_as)  {}" by blast
        have "(k,ev)  ?N   set (filter ((=) (k, ev)) n_as)  {}" using False by auto
        also have "  length (filter ((=) (k, ev)) n_as)  0" by blast
        also have "  compute_nr_of_jordan_blocks A ev k  0"
          unfolding compute_nr_of_jordan_blocks[OF jnf False] by simp
        also have "  (k,ev)  ?C" unfolding C using False True by auto
        finally show ?thesis by auto
      next
        (* k lies beyond the order of ev as a root of char_poly A *)
        case False
        hence "(k,ev)  ?C" unfolding C by auto
        moreover from False kk_def have k: "k > Polynomial.order ev (char_poly A)" by auto
        with jordan_nf_block_size_order_bound[OF jnf, of k ev]
        have "(k,ev)  ?N" by auto
        ultimately show ?thesis by simp
      qed
    qed
  }
  thus ?thesis unfolding C N[symmetric] by auto
qed

(* Uniqueness of Jordan normal forms up to the set of blocks: two Jordan
   normal forms n_as and m_bs of the same matrix contain the same blocks
   (as sets, so multiplicities and order are not compared).
   Both sets are characterised by compute_set_of_jordan_blocks, which only
   depends on A. *)
lemma jordan_nf_unique: assumes "jordan_nf (A :: 'a mat) n_as" and "jordan_nf A m_bs" 
shows "set n_as = set m_bs" 
proof -
  from compute_set_of_jordan_blocks[OF assms(1), unfolded compute_set_of_jordan_blocks[OF assms(2)]]
  show ?thesis by auto
qed

text ‹One might get more fine-grained and prove the uniqueness lemma for multisets, 
   so one takes multiplicities into account. For the moment we don't require this for
  complexity analysis, so it remains as future work.›

end

end

Theory Spectral_Radius

(*  
    Author:      René Thiemann 
                 Akihisa Yamada
    License:     BSD
*)
section ‹Spectral Radius Theory›

text ‹The following results show that the spectral radius characterizes the polynomial growth
  of matrix powers.›

theory Spectral_Radius
imports
  Jordan_Normal_Form_Existence
begin

(* The spectrum of A: the set of all values satisfying eigenvalue A. *)
definition "spectrum A = Collect (eigenvalue A)"

(* For a square matrix over a field, the spectrum is exactly the set of
   roots of the characteristic polynomial (via eigenvalue_root_char_poly). *)
lemma spectrum_root_char_poly: assumes A: "(A :: 'a :: field mat)  carrier_mat n n"
  shows "spectrum A = {k. poly (char_poly A) k = 0}"
  unfolding spectrum_def eigenvalue_root_char_poly[OF A, symmetric] by auto

(* The spectrum of an n-by-n matrix over a field is finite, and its
   cardinality is bounded in terms of n.
   Both facts come from viewing the spectrum as the root set of the
   characteristic polynomial CP, which is non-zero because its coefficient
   at n is 1 and which has degree n (degree_monic_char_poly). *)
lemma card_finite_spectrum: assumes A: "(A :: 'a :: field mat)  carrier_mat n n"
  shows "finite (spectrum A)" "card (spectrum A)  n"
proof -
  define CP where "CP = char_poly A"
  from spectrum_root_char_poly[OF A] have id: "spectrum A = { k. poly CP k = 0}"
    unfolding CP_def by auto
  from degree_monic_char_poly[OF A] have d: "degree CP = n" and c: "coeff CP n = 1"
    unfolding CP_def by auto
  from c have "CP  0" by auto
  from poly_roots_finite[OF this]
  show "finite (spectrum A)" unfolding id .
  from poly_roots_degree[OF CP  0]
  show "card (spectrum A)  n" unfolding id using d by simp
qed

(* A complex square matrix of positive dimension has a non-empty spectrum.
   The characteristic polynomial has positive degree, hence is not
   constant, and the fundamental theorem of algebra yields a root. *)
lemma spectrum_non_empty: assumes A: "(A :: complex mat)  carrier_mat n n"
  and n: "n > 0"
  shows "spectrum A  {}"
proof - 
  define CP where "CP = char_poly A"
  from spectrum_root_char_poly[OF A] have id: "spectrum A = { k. poly CP k = 0}"
    unfolding CP_def by auto
  from degree_monic_char_poly[OF A] have d: "degree CP > 0" using n
    unfolding CP_def by auto
  hence "¬ constant (poly CP)" by (simp add: constant_degree)
  from fundamental_theorem_of_algebra[OF this] show ?thesis unfolding id by auto
qed

(* Spectral radius of a complex matrix: the maximum of the norms of the
   elements of its spectrum.  The characteristic lemma
   spectral_radius_mem_max is stated for square matrices of positive
   dimension, where the spectrum is finite and non-empty. *)
definition spectral_radius :: "complex mat  real" where 
  "spectral_radius A = Max (norm ` spectrum A)"

(* Characteristic properties of spectral_radius for a square matrix of
   positive dimension:
   (1) the spectral radius is itself the norm of some element of the spectrum;
   (2) it relates to every norm a of a spectrum element as a maximum.
   Max is rewritten to Sup_fin (Sup_fin_Max[symmetric]) so that
   Sup_fin.closed and Sup_fin.coboundedI apply; both need the norm image SA
   to be finite, and the first also needs it to be non-empty. *)
lemma spectral_radius_mem_max: assumes A: "A  carrier_mat n n"
  and n: "n > 0"
  shows "spectral_radius A  norm ` spectrum A" (is ?one)
  "a  norm ` spectrum A  a  spectral_radius A"
proof -
  define SA where "SA = norm ` spectrum A"
  from card_finite_spectrum[OF A]
  have fin: "finite SA" unfolding SA_def by auto
  from spectrum_non_empty[OF A n] have ne: "SA  {}" unfolding SA_def by auto
  note d = spectral_radius_def SA_def[symmetric] Sup_fin_Max[symmetric]
  show ?one unfolding d
    by (rule Sup_fin.closed[OF fin ne], auto simp: sup_real_def)
  assume "a  norm ` spectrum A"
  thus "a  spectral_radius A" unfolding d
    by (intro Sup_fin.coboundedI[OF fin])
qed

text ‹If spectral radius is at most 1, and JNF exists, then we have polynomial growth.›

(* Polynomial growth bound: with spectral radius bounded by 1, a Jordan
   normal form available, the powers A ^ k admit a norm_bound of the form
   c1 + c2 * k ^ (n - 1).
   The characteristic polynomial is factored into linear factors with roots
   as (char_poly_factorized) and factored_char_poly_norm_bound is applied;
   its side conditions are: every root occurs at most n times in as (as has
   length n), and every root has norm bounded by 1 (from the spectral
   radius assumption). *)
lemma spectral_radius_jnf_norm_bound_le_1: assumes A: "A  carrier_mat n n"
  and sr_1: "spectral_radius A  1"
  and jnf_exists: " n_as. jordan_nf A n_as"
  shows " c1 c2.  k. norm_bound (A ^m k) (c1 + c2 * of_nat k ^ (n - 1))"
proof -
  let ?p = "char_poly A"
  from char_poly_factorized[OF A] obtain as where cA: "char_poly A = (aas. [:- a, 1:])" 
    and len: "length as = n" by auto  
  show ?thesis
  proof (rule factored_char_poly_norm_bound[OF A cA jnf_exists])
    fix a
    show "length (filter ((=) a) as)  n" using len by auto
    assume "a  set as"
    from linear_poly_root[OF this]
    have "poly ?p a = 0" unfolding cA by simp
    with spectrum_root_char_poly[OF A] 
    have mem: "norm a  norm ` spectrum A" by auto
    (* the spectrum has an element, so the dimension cannot be 0 *)
    with card_finite_spectrum[OF A] have "n > 0" by (cases n, auto)
    from spectral_radius_mem_max(2)[OF A this mem] sr_1 
    show "norm a  1" by auto
  qed
qed

text ‹If spectral radius is smaller than 1, and JNF exists, then we have a constant bound.›

(* Constant bound: with spectral radius strictly below 1 and a Jordan normal
   form available, all powers A ^ k share one norm_bound c.
   The proof instantiates factored_char_poly_norm_bound with exponent
   0 - 1, so the polynomial part of the bound no longer grows with k (the
   final auto step collapses it to a constant).  The multiplicity condition
   for roots of norm 1 holds vacuously, because every root has norm < 1. *)
lemma spectral_radius_jnf_norm_bound_less_1: assumes A: "A  carrier_mat n n"
  and sr_1: "spectral_radius A < 1"
  and jnf_exists: " n_as. jordan_nf A n_as" 
  shows " c.  k. norm_bound (A ^m k) c"
proof -
  let ?p = "char_poly A"
  from char_poly_factorized[OF A] obtain as where cA: "char_poly A = (aas. [:- a, 1:])" by auto
  have " c1 c2.  k. norm_bound (A ^m k) (c1 + c2 * of_nat k ^ (0 - 1))"
  proof (rule factored_char_poly_norm_bound[OF A cA jnf_exists])
    fix a
    assume "a  set as"
    from linear_poly_root[OF this]
    have "poly ?p a = 0" unfolding cA by simp
    with spectrum_root_char_poly[OF A] 
    have mem: "norm a  norm ` spectrum A" by auto
    with card_finite_spectrum[OF A] have "n > 0" by (cases n, auto)
    from spectral_radius_mem_max(2)[OF A this mem] sr_1 
    have lt: "norm a < 1" by auto
    thus "norm a  1" by auto
    (* premise norm a = 1 contradicts lt, so this goal is trivial *)
    from lt show "norm a = 1  length (filter ((=) a) as)  0" by auto
  qed
  thus ?thesis by auto
qed

text ‹If spectral radius is larger than 1, then we have exponential growth.›

(* Exponential growth: if the spectral radius exceeds 1, there are a
   non-zero vector v and a scalar c with norm c > 1 such that applying
   A ^ k to v scales v by c ^ k.
   The witnesses are an eigenvalue ev of norm > 1 (it exists because the
   spectral radius is attained, spectral_radius_mem_max) and one of its
   eigenvectors; eigenvector_pow gives the equation for the k-th power. *)
lemma spectral_radius_gt_1: assumes A: "A  carrier_mat n n"
  and n: "n > 0"
  and sr_1: "spectral_radius A > 1"
  shows " v c. v  carrier_vec n  norm c > 1  v  0v n  A ^m k *v v = c^k v v"
proof -
  from sr_1 spectral_radius_mem_max[OF A n] obtain ev 
    where ev: "ev  spectrum A" and gt: "norm ev > 1" by auto
  (* note: the fact name ev is re-bound from "is in the spectrum" to
     "has eigenvector v" *)
  from ev[unfolded spectrum_def eigenvalue_def[abs_def]] 
    obtain v where ev: "eigenvector A v ev" by auto
  from eigenvector_pow[OF A this] this[unfolded eigenvector_def] A gt
  show ?thesis
    by (intro exI[of _ v], intro exI[of _ ev], auto)
qed


text ‹If spectral radius is at most 1 for a complex matrix, then we have polynomial growth.›

(* Specialisation of spectral_radius_jnf_norm_bound_le_1 to complex
   matrices: the Jordan-normal-form premise is discharged with
   jordan_nf_exists and char_poly_factorized, so only the carrier and the
   spectral-radius assumptions remain. *)
lemma spectral_radius_jnf_norm_bound_le_1_upper_triangular: assumes A: "(A :: complex mat)  carrier_mat n n"
  and sr_1: "spectral_radius A  1"
  shows " c1 c2.  k. norm_bound (A ^m k) (c1 + c2 * of_nat k ^ (n - 1))"
  by (rule spectral_radius_jnf_norm_bound_le_1[OF A sr_1],
    insert char_poly_factorized[OF A] jordan_nf_exists[OF A], blast)

text ‹If spectral radius is less than 1 for a complex matrix, then we have a constant bound.›

(* Specialisation of spectral_radius_jnf_norm_bound_less_1 to complex
   matrices: the Jordan-normal-form premise is discharged with
   jordan_nf_exists and char_poly_factorized. *)
lemma spectral_radius_jnf_norm_bound_less_1_upper_triangular: assumes A: "(A :: complex mat)  carrier_mat n n"
  and sr_1: "spectral_radius A < 1"
  shows " c.  k. norm_bound (A ^m k) c"
  by (rule spectral_radius_jnf_norm_bound_less_1[OF A sr_1],
    insert char_poly_factorized[OF A] jordan_nf_exists[OF A], blast)

text ‹And we can also get a quantitative approximation via the multiplicity of the eigenvalues.›

(* Refined growth bound for complex matrices: if the spectral radius is
   bounded by 1 and every root of the characteristic polynomial with norm 1
   has Polynomial.order bounded by d (assumption eq_1), then the powers
   A ^ k admit a norm_bound of the form c1 + c2 * k ^ (d - 1).
   Note that the variable k bound in eq_1 does not occur in its body.
   Proof outline:
   - le_1: every root of the characteristic polynomial is an eigenvalue and
     therefore has norm bounded by the spectral radius;
   - the characteristic polynomial is factored with root list as;
   - for a root ev, its number of occurrences ?k in as is bounded by its
     order: the ?k-th power of the linear factor divides the polynomial
     (poly_linear_exp_linear_factors_rev applied to the negated roots), and
     order_max turns this into the bound k;
   - eq_1 then bounds ?k by d for roots of norm 1. *)
lemma spectral_radius_poly_bound: fixes A :: "complex mat"
  assumes A: "A  carrier_mat n n" 
  and sr_1: "spectral_radius A  1"
  and eq_1: " ev k. poly (char_poly A) ev = 0  norm ev = 1  Polynomial.order ev (char_poly A)  d"
  shows " c1 c2.  k. norm_bound (A ^m k) (c1 + c2 * of_nat k ^ (d - 1))"
proof -
  {
    fix ev
    assume "poly (char_poly A) ev = 0"
    with eigenvalue_root_char_poly[OF A] have ev: "eigenvalue A ev" by simp
    hence "norm ev  norm ` spectrum A" unfolding spectrum_def by auto
    from spectral_radius_mem_max(2)[OF A eigenvalue_imp_nonzero_dim[OF A ev] this] sr_1    
    have "norm ev  1" by auto
  } note le_1 = this
  let ?p = "char_poly A"
  from char_poly_factorized[OF A] obtain as where cA: "char_poly A = (aas. [:- a, 1:])" 
    and lenn: "length as = n" by auto 
  from degree_monic_char_poly[OF A] have deg: "degree (char_poly A) = n" by auto
  show ?thesis
  proof (rule factored_char_poly_norm_bound[OF A cA jordan_nf_exists[OF A]], rule cA) 
    fix ev
    assume "ev  set as"
    hence root: "poly (char_poly A) ev = 0" unfolding cA by (rule linear_poly_root)
    from le_1[OF root] show "norm ev  1" .
    let ?k = "length (filter ((=) ev) as)"
    (* the library lemma is applied to the negated roots, so both the count
       and the product are first rewritten to that form *)
    have len: "length (filter ((=) (- ev)) (map uminus as)) = length (filter ((=) ev) as)"
      by (induct as, auto)
    have prod: "(amap uminus as. [:a, 1:]) = (aas. [:- a, 1:])"
      by (induct as, auto)
    have dvd: "[:- ev, 1:] ^ ?k dvd char_poly A" unfolding cA using 
      poly_linear_exp_linear_factors_rev[of "- ev" "map uminus as"] 
      unfolding len prod .
    from ev  set as deg lenn
    have "degree (char_poly A)  0" by (cases as, auto)
    hence "char_poly A  0" by auto
    from order_max[OF dvd this] have k: "?k  Polynomial.order ev (char_poly A)" .
    assume "norm ev = 1"
    from eq_1[OF root this] k
    show "?k  d" by simp
  qed
qed

end

Theory DL_Missing_List

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Missing Lemmas of List›

theory DL_Missing_List
imports Main
begin

(* Indexing into map f (zip xs ys): for an index valid in both xs and ys,
   the i-th element is f applied to the pair of the i-th elements. *)
lemma nth_map_zip:
assumes "i < length xs"
assumes "i < length ys"
shows "map f (zip xs ys) ! i = f (xs ! i, ys ! i)"
 using nth_zip nth_map length_zip by (simp add: assms(1) assms(2))

(* Variant of nth_map_zip with a single premise: the index only has to be
   valid for the mapped zipped list itself. *)
lemma nth_map_zip2:
assumes "i < length (map f (zip xs ys))"
shows "map f (zip xs ys) ! i = f (xs ! i, ys ! i)"
 using nth_zip nth_map length_zip assms by simp


(* find_first a xs: position (counting from 0) of the first element of xs
   that equals a.  On the empty list the result is undefined, so for an a
   that does not occur in xs the function yields no usable value; the
   lemmas about find_first therefore assume membership or distinctness. *)
fun find_first where
"find_first a [] = undefined" |
"find_first a (x # xs) = (if x = a then 0 else Suc (find_first a xs))"

(* The index returned by find_first for a member of xs is a valid index of
   xs.  Despite the suffix _le in the name, the stated bound is strict (<).
   Proved by induction on xs. *)
lemma find_first_le:
assumes "a  set xs"
shows "find_first a xs < length xs"
using assms proof (induction xs)
  case (Cons x xs)
  then show ?case
    using find_first.simps(2) nth_Cons_0 nth_Cons_Suc set_ConsD by auto
qed auto

(* Looking up the index returned by find_first gives back the searched
   element, provided the element occurs in the list.  Induction on xs. *)
lemma nth_find_first:
assumes "a  set xs"
shows "xs ! (find_first a xs) = a"
using assms proof (induction xs)
  case (Cons x xs)
  then show ?case
    using find_first.simps(2) nth_Cons_0 nth_Cons_Suc set_ConsD by auto
qed auto

(* In a distinct list, find_first inverts indexing: searching for the
   element at a valid position i returns i.  Induction on xs with the index
   generalised, followed by a case split on the index. *)
lemma find_first_unique:
assumes "distinct xs"
and "i < length xs"
shows "find_first (xs ! i) xs = i"
using assms proof (induction xs arbitrary:i)
  case (Cons x xs i)
  then show ?case by (cases i; auto)
qed auto

end

Theory DL_Rank

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Matrix Rank›

theory DL_Rank
imports VS_Connect DL_Missing_List
 Determinant
 Missing_VectorSpace
begin

(* If a finite set S of vectors spans a space whose dimension equals
   card S, then S is linearly independent.
   Proof outline: span_vs S is a vector space in which S is a generating
   set of size equal to the dimension, so S is a basis there
   (vectorspace.dim_gen_is_basis); linear independence inside the span is
   then transferred back to V with module.span_li_not_depend. *)
lemma (in vectorspace) full_dim_span:
assumes "S  carrier V"
and "finite S"
and "vectorspace.dim K (span_vs S) = card S"
shows "lin_indpt S"
proof -
  have "vectorspace K (span_vs S)"
    using field.field_axioms vectorspace_def submodule_is_module[OF span_is_submodule[OF assms(1)]] by metis
  have "S  carrier (span_vs S)"  by (simp add: assms(1) in_own_span)
  have "LinearCombinations.module.span K (vs (span S)) S = carrier (vs (span S))"
    using module.span_li_not_depend[OF _ span_is_submodule[OF assms(1)]]
    by (simp add: assms(1) in_own_span)
  have "vectorspace.basis K (vs (span S)) S"
    using vectorspace.dim_gen_is_basis[OF ‹vectorspace K (span_vs S) ‹finite S S  carrier (span_vs S)
     ‹LinearCombinations.module.span K (vs (span S)) S = carrier (vs (span S))]  ‹vectorspace.dim K (span_vs S) = card S
    by simp
  then have "LinearCombinations.module.lin_indpt K (vs (span S)) S"
    using vectorspace.basis_def[OF ‹vectorspace K (span_vs S)] by blast
  then show ?thesis using module.span_li_not_depend[OF _ span_is_submodule[OF assms(1)]]
    by (simp add: assms(1) in_own_span)
qed

(* The dimension of the span of a finite set S equals the cardinality of
   any U that is maximal among the linearly independent subsets of S.
   Proof outline:
   - U is linearly independent and contained in S, also when viewed inside
     span_vs S;
   - span U = span S, by contradiction: otherwise some s in S lies outside
     span U, and adding s to U gives a larger independent subset of S,
     contradicting maximality;
   - hence U is a basis of span_vs S and vectorspace.dim_basis gives the
     dimension. *)
lemma (in vectorspace) dim_span:
assumes "S  carrier V"
and "finite S"
and "maximal U (λT. T  S  lin_indpt T)"
shows "vectorspace.dim K (span_vs S) = card U"
proof -
  have "lin_indpt U" "U  S" by (metis assms(3) maximal_def)+
  then have "U  span S" using in_own_span[OF assms(1)] by blast
  then have lin_indpt: "LinearCombinations.module.lin_indpt K (span_vs S) U"
    using module.span_li_not_depend(2)[OF U  span S] ‹lin_indpt U assms(1) span_is_submodule by blast
  have "span U = span S"
  proof (rule ccontr)
    assume "span U  span S"
    have "span U  span S" using span_is_monotone US by metis
    then have "¬ S  span U" by (meson U  S ‹span U  span S assms(1) span_is_submodule
      span_is_subset subset_antisym subset_trans)
    then obtain s where "sS" "s  span U" by blast
    then have "lin_indpt (U{s})" using lindep_span
      by (meson U  S ‹lin_indpt U assms(1) lin_dep_iff_in_span rev_subsetD span_mem subset_trans)
    have "sU" using U  S s  span U assms(1) span_mem by auto
    then have "(U{s})  S  lin_indpt (U{s})" using U  S ‹lin_indpt (U  {s}) s  S by auto
    then have "¬maximal U (λT. T  S  lin_indpt T)"
      unfolding maximal_def using Un_subset_iff s  U insert_subset  order_refl by auto
    then show False using assms by metis
  qed
  then have span:"LinearCombinations.module.span K (vs (span S)) U = span S"
    using module.span_li_not_depend[OF U  span S]
    by (simp add: LinearCombinations.module.span_is_submodule assms(1) module_axioms)
  have "vectorspace K (vs (span S))"
    using field.field_axioms vectorspace_def submodule_is_module[OF span_is_submodule[OF assms(1)]] by metis
  then have "vectorspace.basis K (vs (span S)) U" using vectorspace.basis_def[OF ‹vectorspace K (vs (span S))]
    by (simp add: span U  span S lin_indpt)
  then show ?thesis
    using U  S ‹vectorspace K (vs (span S)) assms(2) infinite_super vectorspace.dim_basis by blast
qed

(* Column rank of a matrix: the dimension (over class_ring) of the span of
   the set of columns of A. *)
definition (in vec_space) rank ::"'a mat  nat"
where "rank A = vectorspace.dim class_ring (span_vs (set (cols A)))"

(* The rank of A equals the cardinality of any S that is maximal among the
   linearly independent subsets of the column set of A.
   Direct instance of dim_span for the (finite) set of columns. *)
lemma (in vec_space) rank_card_indpt:
assumes "A  carrier_mat n nc"
assumes "maximal S (λT. T  set (cols A)  lin_indpt T)"
shows "rank A = card S"
proof -
  have "set (cols A)  carrier_vec n" using cols_dim assms(1) by blast
  have "finite (set (cols A))" by blast
  show ?thesis using dim_span[OF ‹set (cols A)  carrier_vec n ‹finite (set (cols A)) assms(2)]
    unfolding rank_def by blast
qed

(* Existence of a maximal extension: if S is finite, every set satisfying P
   is bounded by S (assumption maxc), and B satisfies P, then some finite
   set A is maximal with respect to P and extends B.
   Proved by well-founded induction on the finite set S - B
   (finite_psubset_induct): either B is already maximal, or a strictly
   larger B' with P B' exists, S - B' is strictly smaller, and the
   induction hypothesis applies to B'. *)
lemma maximal_exists_superset:
  assumes "finite S"
  assumes maxc: "A. P A  A  S" and "P B"
  shows "A. finite A  maximal A P  B  A"
proof -
  have "finite (S-B)" using assms(1) assms(3) infinite_super maxc by blast
  then show ?thesis using P B
  proof (induction "S-B" arbitrary:B rule: finite_psubset_induct)
    case (psubset B)
    then show ?case
    proof (cases "maximal B P")
      case True
      then show ?thesis using order_refl psubset.hyps by (metis assms(1) maxc psubset.prems rev_finite_subset)
    next
      case False
      then obtain B' where "B  B'" "P B'" using maximal_def psubset.prems by (metis dual_order.order_iff_strict)
      then have "B'  S" "B  S" using maxc P B by auto
      then have "S - B'  S - B" using B  B' by blast
      then show ?thesis using psubset(2)[OF S - B'  S - B P B'] using B  B' by fast
    qed
  qed
qed

(* Relates the rank of A to the cardinality of a linearly independent set U
   of columns of A.
   U is extended to a finite maximal independent set S of columns
   (maximal_exists_superset); the rank equals card S (rank_card_indpt) and
   card_mono compares card U with card S. *)
lemma (in vec_space) rank_ge_card_indpt:
assumes "A  carrier_mat n nc"
assumes "U  set (cols A)"
assumes "lin_indpt U"
shows "rank A  card U"
proof -
  obtain S where "maximal S (λT. T  set (cols A)  lin_indpt T)" "US" "finite S"
    using maximal_exists_superset[of "set (cols A)" "(λT. T  set (cols A)  lin_indpt T)" U]
    using List.finite_set assms(2) assms(3) maximal_exists_superset by blast
  then show ?thesis
    unfolding rank_card_indpt[OF A  carrier_mat n nc ‹maximal S (λT. T  set (cols A)  lin_indpt T)]
    using card_mono by blast
qed

(* A matrix with nc pairwise distinct, linearly independent columns has
   rank nc.  The whole column set is then itself the maximal independent
   subset, and distinctness makes its cardinality equal to the number of
   columns (distinct_card). *)
lemma (in vec_space) lin_indpt_full_rank:
assumes "A  carrier_mat n nc"
assumes "distinct (cols A)"
assumes "lin_indpt (set (cols A))"
shows "rank A = nc"
proof -
  have "maximal (set (cols A)) (λT. T  set (cols A)  lin_indpt T)"
    by (simp add: assms(3) maximal_def subset_antisym)
  then have "rank A = card (set (cols A))" using assms(1) vec_space.rank_card_indpt by blast
  then show ?thesis using assms(1) assms(2) distinct_card by fastforce
qed

(* Relates the rank of A to its number of columns nc.
   A maximal independent set S of columns is obtained with maximal_exists;
   card S is compared with the cardinality of the column set and then with
   the length of the column list (card_length, cols_length), and
   rank_card_indpt identifies the rank with card S. *)
lemma (in vec_space) rank_le_nc:
assumes "A  carrier_mat n nc"
shows "rank A  nc"
proof -
  obtain S where "maximal S (λT. T  set (cols A)  lin_indpt T)"
    using maximal_exists[of "(λT. T  set (cols A)  lin_indpt T)" "card (set (cols A))" "{}"]
    by (meson List.finite_set card_mono empty_iff empty_subsetI finite_lin_indpt2 rev_finite_subset)
  then have "card S  card (set (cols A))" by (simp add: card_mono maximal_def)
  then have "card S  nc"
    using assms(1) cols_length card_length carrier_matD(2) by (metis dual_order.trans)
  then show ?thesis
    using rank_card_indpt[OF A  carrier_mat n nc ‹maximal S (λT. T  set (cols A)  lin_indpt T)]
    by simp
qed

(* Converse of lin_indpt_full_rank: if A has rank nc and its nc columns are
   pairwise distinct, then the column set is linearly independent.
   Facts 1-3 are exactly the premises of full_dim_span for the column set:
   it lies in the carrier, is finite, and the dimension of its span equals
   its cardinality. *)
lemma (in vec_space) full_rank_lin_indpt:
assumes "A  carrier_mat n nc"
assumes "rank A = nc"
assumes "distinct (cols A)"
shows "lin_indpt (set (cols A))"
proof -
  have 1:"set (cols A)  carrier_vec n" using assms(1) cols_dim by blast
  have 2:"finite (set (cols A))" by simp
  have "card (set (cols A)) = nc"
    using assms(1) assms(3) distinct_card by fastforce
  have 3:"vectorspace.dim class_ring (span_vs (set (cols A))) = card (set (cols A))"
    using ‹rank A = nc[unfolded rank_def]
    using assms(1) assms(3) distinct_card by fastforce
  show ?thesis using full_dim_span[OF 1 2 3] .
qed


(* Matrix-vector product as a linear combination of columns: for a matrix
   with pairwise distinct columns and a coefficient function a on vectors,
   multiplying A with the vector whose i-th entry is a (col A i) equals
   lincomb a over the set of columns.
   Proved entry-wise (eq_vecI): the sum over the column set is re-indexed
   along the bijection between positions and columns (which needs
   distinctness), and then rewritten into the scalar product of row i of A
   with the coefficient vector. *)
lemma (in vec_space) mat_mult_eq_lincomb:
assumes "A  carrier_mat n nc"
assumes "distinct (cols A)"
shows "A *v (vec nc (λi. a (col A i))) = lincomb a (set (cols A))"
proof (rule eq_vecI)
  have "finite (set (cols A))" using assms(1) by simp
  then show "dim_vec (A *v (vec nc (λi. a (col A i)))) = dim_vec (lincomb a (set (cols A)))"
    using assms cols_dim vec_space.lincomb_dim by (metis dim_mult_mat_vec carrier_matD(1))
  fix i assume "i < dim_vec (lincomb a (set (cols A)))"
  then have "i < n" using ‹dim_vec (A *v (vec nc (λi. a (col A i)))) = dim_vec (lincomb a (set (cols A))) assms by auto
  have "set (cols A)  carrier_vec n" using cols_dim A  carrier_mat n nc carrier_matD(1) by blast
  have "bij_betw (nth (cols A)) {..<length (cols A)} (set (cols A))"
    unfolding bij_betw_def by (rule conjI, simp add: inj_on_nth ‹distinct (cols A);
    metis subset_antisym in_set_conv_nth lessThan_iff rev_image_eqI subsetI
    image_subsetI lessThan_iff nth_mem)
  then have " (xset (cols A). a x * x $ i) =
    (j{..<length (cols A)}. a (cols A ! j) * (cols A ! j) $ i)"
    using bij_betw_imp_surj_on bij_betw_imp_inj_on by (metis (no_types, lifting) sum.reindex_cong)
  also have "... = (j{..<length (cols A)}. a (col A j) * (cols A ! j) $ i)"
    using assms(1) assms(2) find_first_unique[OF ‹distinct (cols A)] i < n by auto
  also have "... = (j{..<length (cols A)}. (cols A ! j) $ i * a (col A j))" by (metis mult_commute_abs)
  also have "... = (j{..<length (cols A)}. row A i $ j * a (col A j))" using i < n assms(1) assms(2) by auto
  finally show "(A *v (vec nc (λi. a (col A i)))) $ i = lincomb a (set (cols A)) $ i"
    unfolding lincomb_index[OF i < n ‹set (cols A)  carrier_vec n]
    unfolding mult_mat_vec_def scalar_prod_def
    using i < n assms(1) atLeast0LessThan lessThan_def carrier_matD(1) index_vec sum.cong by auto
qed

(* Reverse direction of mat_mult_eq_lincomb: for distinct columns, the
   product A * v equals the linear combination of the columns in which each
   column gets the entry of v at that column's position; the position is
   recovered with find_first.
   find_first_unique shows that the coefficient vector built this way is v
   itself, so mat_mult_eq_lincomb applies. *)
lemma (in vec_space) lincomb_eq_mat_mult:
assumes "A  carrier_mat n nc"
assumes "v  carrier_vec nc"
assumes "distinct (cols A)"
shows "lincomb (λa. v $ find_first a (cols A)) (set (cols A)) = (A *v v)"
proof -
  have "i. i < nc  find_first (col A i) (cols A) = i"
    using assms(1) assms(3) find_first_unique by fastforce
  then have "vec nc (λi. v $ find_first (col A i) (cols A)) = v"
    using assms(2) by auto
  then show ?thesis
    using mat_mult_eq_lincomb[where a = "(λa. v $ find_first a (cols A))", OF assms(1) assms(3)] by auto
qed

(* Introduction rule for linear dependence of columns: if A has distinct
   columns and maps some non-zero vector v to the zero vector, then the
   column set of A is linearly dependent.
   Facts 1-6 are the premises of lin_dep_crit: the column set is finite,
   the coefficients are the entries of v located with find_first, the
   coefficient of some column i is non-zero (v has a non-zero entry), and
   the linear combination is zero (lincomb_eq_mat_mult together with
   A * v = 0). *)
lemma (in vec_space) lin_depI:
assumes "A  carrier_mat n nc"
assumes "v  carrier_vec nc" "v  0v nc" "A *v v = 0v n"
assumes "distinct (cols A)"
shows "lin_dep (set (cols A))"
proof -
  have 1: "finite (set (cols A))" by simp
  have 2: "set (cols A)  set (cols A)" by auto
  have 3: "(λa. v $ find_first a (cols A))  set (cols A)  UNIV" by simp
  obtain i where "v $ i  0" "i < nc"
    using v  0v nc
    by (metis assms(2) dim_vec carrier_vecD vec_eq_iff zero_vec_def index_zero_vec(1))
  then have "i < dim_col A" using assms(1) by blast
  have 4:"col A i  set (cols A)"
    using cols_nth[OF i < dim_col A] i < dim_col A in_set_conv_nth by fastforce
  have 5:"v $ find_first (col A i) (cols A)  0"
    using find_first_unique[OF ‹distinct (cols A)] cols_nth[OF i < dim_col A] i < nc v $ i  0
    assms(1) by auto
  have 6:"lincomb (λa. v $ find_first a (cols A)) (set (cols A)) = 0v n"
    using assms(1) assms(2) assms(4) assms(5) lincomb_eq_mat_mult by auto
  show ?thesis using lin_dep_crit[OF 1 2 _ 4 5 6] by metis
qed

(* Elimination rule for linear dependence of columns: if the distinct
   columns of A are linearly dependent, one obtains a non-zero vector v
   that A maps to the zero vector.
   finite_lin_dep provides coefficients a with a zero linear combination
   and a column w with non-zero coefficient; v collects the coefficients
   column by column.  Fact 2 (v is non-zero) uses the position of w among
   the columns, fact 3 uses mat_mult_eq_lincomb. *)
lemma (in vec_space) lin_depE:
assumes "A  carrier_mat n nc"
assumes "lin_dep (set (cols A))"
assumes "distinct (cols A)"
obtains v where "v  carrier_vec nc" "v  0v nc" "A *v v = 0v n"
proof -
  have "finite (set (cols A))" by simp
  obtain a w where "a  set (cols A)  UNIV" "lincomb a (set (cols A)) = 0v n" "w  set (cols A)" "a w  0"
    using finite_lin_dep[OF ‹finite (set (cols A)) ‹lin_dep (set (cols A))]
    using assms(1) cols_dim carrier_matD(1) by blast
  define v where "v = vec nc (λi. a (col A i))"
  have 1:"v  carrier_vec nc" by (simp add: v_def)
  have 2:"v  0v nc"
  proof -
    obtain i where "w = col A i" "i < length (cols A)"
      by (metis w  set (cols A) cols_length cols_nth in_set_conv_nth)
    have "v $ i  0"
      unfolding v_def
      using a w  0[unfolded w = col A i] index_vec[OF i < length (cols A)]
      assms(1) cols_length carrier_matD(2) by (metis (no_types) A  carrier_mat n nc
      f. vec (length (cols A)) f $ i = f i a (col A i)  0 cols_length carrier_matD(2))
    then show ?thesis using i < length (cols A) assms(1) by auto
  qed
  have 3:"A *v v = 0v n" unfolding v_def
    using ‹lincomb a (set (cols A)) = 0v n mat_mult_eq_lincomb[OF A  carrier_mat n nc ‹distinct (cols A)] by auto
  show thesis using 1 2 3 by (simp add: that)
qed

(* A square matrix with a repeated column has rank strictly below n.
   As in rank_le_nc a maximal independent set S of columns is chosen; since
   the column list is not distinct, the column set has fewer elements than
   the list has entries (card_distinct, card_length), which makes the bound
   on card S strict. *)
lemma (in vec_space) non_distinct_low_rank:
assumes "A  carrier_mat n n"
and "¬ distinct (cols A)"
shows "rank A < n"
proof -
  obtain S where "maximal S (λT. T  set (cols A)  lin_indpt T)"
    using maximal_exists[of "(λT. T  set (cols A)  lin_indpt T)" "card (set (cols A))" "{}"]
    by (meson List.finite_set card_mono empty_iff empty_subsetI finite_lin_indpt2 rev_finite_subset)
  then have "card S  card (set (cols A))" by (simp add: card_mono maximal_def)
  then have "card S < n"
    using assms(1) cols_length card_length ¬ distinct (cols A) card_distinct carrier_matD(2) nat_less_le
    by (metis dual_order.antisym dual_order.trans)
  then show ?thesis
    using rank_card_indpt[OF A  carrier_mat n n ‹maximal S (λT. T  set (cols A)  lin_indpt T)]
    by simp
qed

text ‹The theorem "det non-zero $\longleftrightarrow$ full rank" is practically proven in det\_0\_iff\_vec\_prod\_zero\_field,
but without an actual definition of the rank.›

(* A square matrix with determinant 0 has rank strictly below n.
   By contradiction: otherwise the rank is n (using rank_le_nc).  A zero
   determinant gives a non-zero vector v in the kernel
   (det_0_iff_vec_prod_zero_field).  If the columns are distinct, full rank
   makes them independent (full_rank_lin_indpt) while v makes them
   dependent (lin_depI); if they are not distinct, non_distinct_low_rank
   contradicts full rank. *)
lemma (in vec_space) det_zero_low_rank:
assumes "A  carrier_mat n n"
and "det A = 0"
shows "rank A < n"
proof (rule ccontr)
  assume "¬ rank A < n"
  then have "rank A = n" using rank_le_nc assms le_neq_implies_less by blast
  obtain v where "v  carrier_vec n" "v  0v n" "A *v v = 0v n"
    using det_0_iff_vec_prod_zero_field[OF assms(1)] assms(2) by blast
  then show False
  proof (cases "distinct (cols A)")
    case True
    then have "lin_indpt (set (cols A))" using full_rank_lin_indpt using ‹rank A = n assms(1) by auto
    then show False using lin_depI[OF assms(1) v  carrier_vec n v  0v n A *v v = 0v n] True by blast
  next
    case False
    then show False using non_distinct_low_rank ‹rank A = n ¬ rank A < n assms(1) by blast
  qed
qed

(* A square matrix with two equal columns at different positions i and j
   has determinant 0.  Reduced to the row version det_identical_rows by
   transposing the matrix (det_transpose, row_transpose). *)
lemma det_identical_cols:
  assumes A: "A  carrier_mat n n"
    and ij: "i  j"
    and i: "i < n" and j: "j < n"
    and r: "col A i = col A j"
  shows "det A = 0"
  using det_identical_rows det_transpose
  by (metis A i ij j carrier_matD(2) transpose_carrier_mat r row_transpose)

(* A square matrix whose determinant is not 0 has full rank n.  (The lemma
   name describes the contrapositive reading: low rank implies zero
   determinant.)
   Proof outline:
   - the columns are distinct, since two equal columns would force the
     determinant to be 0 (det_identical_cols);
   - no non-zero vector is mapped to zero (det_0_iff_vec_prod_zero_field),
     so by lin_depE the columns cannot be linearly dependent;
   - lin_indpt_full_rank concludes. *)
lemma (in vec_space) low_rank_det_zero:
assumes "A  carrier_mat n n"
and "det A  0"
shows "rank A = n"
proof -
  have "distinct (cols A)"
  proof (rule ccontr)
    assume "¬ distinct (cols A)"
    then obtain i j where "ij" "(cols A) ! i = (cols A) ! j" "i<length (cols A)" "j<length (cols A)"
      using distinct_conv_nth by blast
    then have "col A i = col A j" "i<n" "j<n" using assms(1) by auto
    then have "det A = 0"  using det_identical_cols using i  j assms(1) by blast
    then show False using ‹det A  0 by auto
  qed
  have "v. v  carrier_vec n  v  0v n  A *v v  0v n"
    using det_0_iff_vec_prod_zero_field[OF assms(1)] assms(2) by auto
  then have "lin_indpt (set (cols A))" using lin_depE[OF assms(1) _ ‹distinct (cols A)] by auto
  then show ?thesis using lin_indpt_full_rank[OF assms(1) ‹distinct (cols A)] by metis
qed

(* Characterisation of full rank for square matrices via the determinant,
   obtained by combining det_zero_low_rank and low_rank_det_zero. *)
lemma (in vec_space) det_rank_iff:
assumes "A  carrier_mat n n"
shows "det A  0  rank A = n"
  using assms det_zero_low_rank low_rank_det_zero by force

section "Subadditivity of rank"

text ‹Subadditivity is the property of rank, that rank (A + B) <= rank A + rank B.›

(* Sum of two linear combinations over possibly different finite sets:
   if x1 is a linear combination over b1 with coefficients a1, x2 one over
   b2 with coefficients a2, and x combines x1 and x2 with the module
   addition, then x is a linear combination over the union of b1 and b2.
   Its coefficient function combines a1 and a2 pointwise, after each has
   been extended by zero outside its own set.
   Proof outline:
   - each of x1 and x2 is rewritten as a linear combination over the whole
     union, using lincomb_union2 and the fact that a linear combination with
     all-zero coefficients is zero;
   - lincomb_sum then combines the two coefficient functions. *)
lemma (in Module.module) lincomb_add:
assumes "finite (b1  b2)"
assumes "b1  b2  carrier M"
assumes "x1 = lincomb a1 b1" "a1 (b1carrier R)"
assumes "x2 = lincomb a2 b2" "a2 (b2carrier R)"
assumes "x = x1 M x2"
shows "lincomb (λv. (λv. if v  b1 then a1 v else 𝟬) v  (λv. if v  b2 then a2 v else 𝟬) v) (b1  b2) = x"
proof -
  have "finite (b1  (b2-b1))" "finite (b2  (b1-b2))"
       "b1  (b2 - b1)  carrier M" "b2  (b1-b2)  carrier M"
       "b1  (b2 - b1) = {}" "b2  (b1 - b2) = {}"
       "(λb. 𝟬R)  b2 - b1  carrier R" "(λb. 𝟬R)  b1 - b2  carrier R"
    using ‹finite (b1  b2) b1  b2  carrier M a2 (b2carrier R) by auto
  have "lincomb (λb. 𝟬R) (b2 - b1) = 𝟬M" "lincomb (λb. 𝟬R) (b1 - b2) = 𝟬M"
    unfolding lincomb_def using M.finsum_all0 assms(2) lmult_0 subset_iff
    by (metis (no_types, lifting) Un_Diff_cancel2 inf_sup_aci(5) le_sup_iff)+
  then have "x1 = lincomb (λv. if v  b1 then a1 v else 𝟬) (b1  b2)"
            "x2 = lincomb (λv. if v  b2 then a2 v else 𝟬) (b1  b2)"
    using lincomb_union2[OF ‹finite (b1  (b2-b1)) b1  (b2 - b1)  carrier M b1  (b2 - b1) = {} a1 (b1carrier R) (λb. 𝟬R)  b2 - b1  carrier R]
          lincomb_union2[OF ‹finite (b2  (b1-b2)) b2  (b1-b2)  carrier M b2  (b1 - b2) = {} a2 (b2carrier R) (λb. 𝟬R)  b1 - b2  carrier R]
    using assms(2) assms(3) assms(4)  assms(5)  assms(6) by (simp_all add:Un_commute)
  have "(λv. if v  b1 then a1 v else 𝟬)  (b1  b2)  carrier R"
       "(λv. if v  b2 then a2 v else 𝟬)  (b1  b2)  carrier R" using assms(4) assms(6) by auto
  show "lincomb (λv. (λv. if v  b1 then a1 v else 𝟬) v  (λv. if v  b2 then a2 v else 𝟬) v) (b1  b2) = x"
    using lincomb_sum[OF ‹finite (b1  b2) b1  b2  carrier M
    (λv. if v  b1 then a1 v else 𝟬)  (b1  b2)  carrier R (λv. if v  b2 then a2 v else 𝟬)  (b1  b2)  carrier R]
    x1 = lincomb (λv. if v  b1 then a1 v else 𝟬) (b1  b2) x2 = lincomb (λv. if v  b2 then a2 v else 𝟬) (b1  b2) assms(7) by blast
qed

(* Dimension of a sum of two finite-dimensional subspaces W1, W2 of V,
   compared with dim W1 + dim W2.
   NOTE(review): the relation symbol in the conclusion is not visible in
   this copy; the calculation at the end (card_Un_le, gen_ge_dim) indicates
   it is "less than or equal" -- confirm.
   Proof idea: take finite bases b1 of W1 and b2 of W2, show that together
   they generate subspace_sum W1 W2, then bound the dimension by the
   cardinality of the combined generating set. *)
lemma (in vectorspace) dim_subadditive:
assumes "subspace K W1 V"
and "vectorspace.fin_dim K (vs W1)"
assumes "subspace K W2 V"
and "vectorspace.fin_dim K (vs W2)"
shows "vectorspace.dim K (vs (subspace_sum W1 W2))  vectorspace.dim K (vs W1) + vectorspace.dim K (vs W2)"
proof -
  have "vectorspace K (vs W1)" "vectorspace K (vs W2)" "submodule K W1 V" "submodule K W2 V"
    by (simp add: ‹subspace K W1 V ‹subspace K W2 V subspace_is_vs)+
  (* Pick finite bases of both subspaces. *)
  obtain b1 b2 where "vectorspace.basis K (vs W1) b1" "vectorspace.basis K (vs W2) b2" "finite b1" "finite b2"
    using vectorspace.finite_basis_exists[OF ‹vectorspace K (vs W1) ‹vectorspace.fin_dim K (vs W1)]
    using vectorspace.finite_basis_exists[OF ‹vectorspace K (vs W2) ‹vectorspace.fin_dim K (vs W2)]
    by blast
  then have "LinearCombinations.module.gen_set K (vs W1) b1" "LinearCombinations.module.gen_set K (vs W2) b2"
    using ‹vectorspace K (vs W1) ‹vectorspace K (vs W2) vectorspace.basis_def by blast+
  (* Spans taken in the ambient space V coincide with the subspaces
     (via span_li_not_depend). *)
  then have "span b1 = W1" "span b2 = W2"
    using module.span_li_not_depend(1) ‹submodule K W1 V  ‹submodule K W2 V
    ‹vectorspace K (vs W1) ‹vectorspace.basis K (vs W1) b1 ‹vectorspace K (vs W2)
    ‹vectorspace.basis K (vs W2) b2 vectorspace.basis_def by force+
  have "W1  carrier V" "W2  carrier V" using ‹subspace K W1 V ‹subspace K W2 V subspace_def submodule_def by metis+
  have "b1  carrier V"
    using ‹vectorspace.basis K (vs W1) b1 ‹vectorspace K (vs W1) vectorspace.basis_def
    W1  carrier V by fastforce
  have "b2  carrier V"
    using ‹vectorspace.basis K (vs W2) b2 ‹vectorspace K (vs W2) vectorspace.basis_def
    W2  carrier V by fastforce
  have "finite (b1  b2)" "b1  b2  carrier V"
    by (simp_all add: ‹finite b1 ‹finite b2 b2  carrier V b1  carrier V)
  (* Every element of the sum is a linear combination over both bases:
     split x into x1 + x2, express each part over its basis, and merge the
     two combinations with lincomb_add. *)
  have "subspace_sum W1 W2  span (b1b2)"
  proof (rule subsetI)
    fix x assume "x  subspace_sum W1 W2"
    obtain x1 x2 where  "x1  W1" "x2  W2" "x = x1 V x2"
      using imageE[OF x  subspace_sum W1 W2[unfolded submodule_sum_def]]
      by (metis (no_types, lifting) BNF_Def.Collect_case_prodD split_def)
    obtain a1 where "x1 = lincomb a1 b1" "a1 (b1carrier K)"
      using ‹span b1 = W1 finite_span[OF ‹finite b1 b1  carrier V] x1  W1 by auto
    obtain a2 where "x2 = lincomb a2 b2" "a2 (b2carrier K)"
      using ‹span b2 = W2 finite_span[OF ‹finite b2 b2  carrier V] x2  W2 by auto
    obtain a where "x = lincomb a (b1  b2)" using lincomb_add[OF ‹finite (b1  b2) b1  b2  carrier V
      x1 = lincomb a1 b1 a1 (b1carrier K)  x2 = lincomb a2 b2 a2 (b2carrier K) x = x1 V x2] by blast
    then show "x  span (b1  b2)" using finite_span[OF ‹finite (b1  b2) (b1  b2)  carrier V]
      using b1  carrier V b2  carrier V ‹span b1 = W1 ‹span b2 = W2 x  subspace_sum W1 W2 span_union_is_sum by auto
  qed
  (* The basis vectors themselves lie in the sum. *)
  have "b1  W1" "b2  W2"
    using ‹vectorspace K (vs W1) ‹vectorspace K (vs W2) ‹vectorspace.basis K (vs W1) b1
    ‹vectorspace.basis K (vs W2) b2 vectorspace.basis_def local.carrier_vs_is_self by blast+
  then have "b1b2  subspace_sum W1 W2" using ‹submodule K W1 V ‹submodule K W2 V in_sum
    by (metis assms(1) assms(3) dual_order.trans sup_least vectorspace.vsum_comm vectorspace_axioms)
  (* The combined set generates the sum, now with the span computed inside
     the vector space vs (subspace_sum W1 W2); shown by two inclusions. *)
  have "subspace_sum W1 W2 = LinearCombinations.module.span K (vs (subspace_sum W1 W2)) (b1b2)"
  proof (rule subset_antisym)
    have "submodule K (subspace_sum W1 W2) V" by (simp add: ‹submodule K W1 V ‹submodule K W2 V sum_is_submodule)
    show "subspace_sum W1 W2  LinearCombinations.module.span K (vs (subspace_sum W1 W2)) (b1b2)"
      using module.span_li_not_depend(1)[OF b1b2  subspace_sum W1 W2 ‹submodule K (subspace_sum W1 W2) V]
      by (simp add: ‹subspace_sum W1 W2  span (b1  b2))
    show "subspace_sum W1 W2  LinearCombinations.module.span K (vs (subspace_sum W1 W2)) (b1b2)"
      using b1b2  subspace_sum W1 W2 by (metis (full_types) LinearCombinations.module.span_is_subset2
      LinearCombinations.module.submodule_is_module ‹submodule K (subspace_sum W1 W2) V local.carrier_vs_is_self submodule_def)
  qed
  (* Final calculation: dimension vs. size of the generating set (gen_ge_dim),
     size of the combined set vs. card b1 + card b2 (card_Un_le), and
     card of a basis equals the dimension (dim_basis). *)
  have "vectorspace K (vs (subspace_sum W1 W2))" using assms(1) assms(3) subspace_def sum_is_subspace vectorspace.subspace_is_vs by blast
  then have "vectorspace.dim K (vs (subspace_sum W1 W2))  card (b1  b2)"
    using vectorspace.gen_ge_dim[OF ‹vectorspace K (vs (subspace_sum W1 W2)) ‹finite (b1  b2)]
    b1  b2  subspace_sum W1 W2
    ‹subspace_sum W1 W2 = LinearCombinations.module.span K (vs (subspace_sum W1 W2)) (b1  b2)
    local.carrier_vs_is_self by blast
  also have "...  card b1 + card b2" by (simp add: card_Un_le)
  also have "... = vectorspace.dim K (vs W1) + vectorspace.dim K (vs W2)"
    by (metis ‹finite b1 ‹finite b2 ‹vectorspace K (vs W1) ‹vectorspace K (vs W2)
    ‹vectorspace.basis K (vs W1) b1 ‹vectorspace.basis K (vs W2) b2 vectorspace.dim_basis)
  finally show ?thesis by auto
qed

(* If W and X are both submodules of M and X is related to W as stated in
   the third assumption, then X is a submodule of the module (md W).
   NOTE(review): the relation between X and W is not visible in this copy;
   presumably it is set inclusion of X in W -- confirm.
   Proof: unfold submodule_def; the closure properties come from
   "submodule R X M", the module structure from submodule_is_module. *)
lemma (in Module.module) nested_submodules:
assumes "submodule R W M"
assumes "submodule R X M"
assumes "X  W"
shows "submodule R X (md W)"
  unfolding submodule_def
  using X  W submodule_is_module[OF ‹submodule R W M] using ‹submodule R X M[unfolded submodule_def] by auto

(* Vector-space counterpart of nested_submodules: for subspaces W and X of V
   satisfying the third assumption, X is a subspace of (vs W).
   NOTE(review): the relation between X and W is not visible in this copy;
   presumably set inclusion of X in W -- confirm.
   Proof: follows from nested_submodules, subspace_def and subspace_is_vs. *)
lemma (in vectorspace) nested_subspaces:
assumes "subspace K W V"
assumes "subspace K X V"
assumes "X  W"
shows "subspace K X (vs W)"
  using assms nested_submodules subspace_def subspace_is_vs by blast

(* Compares the dimension of a finite-dimensional subspace X with the
   dimension of the (finite-dimensional) ambient space V.
   NOTE(review): the relation symbol in the conclusion is not visible in
   this copy; the use of li_le_dim suggests "dim (vs X) ≤ dim" -- confirm.
   Proof idea: a basis b of (vs X) is linearly independent in V as well, and
   a linearly independent set is bounded by the dimension (li_le_dim). *)
lemma (in vectorspace) subspace_dim:
assumes "subspace K X V" "fin_dim" "vectorspace.fin_dim K (vs X)"
shows "vectorspace.dim K (vs X)  dim"
proof -
  have "vectorspace K (vs X)" using assms(1) subspace_is_vs by auto
  (* A basis of the subspace, taken inside (vs X). *)
  then obtain b where "vectorspace.basis K (vs X) b" using vectorspace.finite_basis_exists
    using assms(3) by blast
  then have "b  carrier V" "LinearCombinations.module.lin_indpt K (vs X) b"
    using vectorspace.basis_def[OF ‹vectorspace K (vs X)] ‹subspace K X V[unfolded subspace_def submodule_def] by auto
  (* Transfer linear independence from (vs X) to V via span_li_not_depend(2). *)
  then have "lin_indpt b"
    by (metis LinearCombinations.module.span_li_not_depend(2) ‹vectorspace K (vs X) ‹vectorspace.basis K (vs X) b
    assms(1) is_module local.carrier_vs_is_self submodule_def vectorspace.basis_def)
  show ?thesis using li_le_dim(2)[OF ‹fin_dim› b  carrier V ‹lin_indpt b]
    using b  carrier V ‹lin_indpt b ‹vectorspace K (vs X) ‹vectorspace.basis K (vs X) b assms(2)
    fin_dim_li_fin vectorspace.dim_basis by fastforce
qed

(* The sum of two finite-dimensional subspaces W1, W2 of V is again
   finite-dimensional.
   Proof idea: take finite generating sets b1 of W1 and b2 of W2 and show
   that the combined set (fact 1: finite, fact 2: contained in the sum) is a
   generating set of vs (subspace_sum W1 W2) (fact 3).
   NOTE(review): set operator symbols are not visible in this copy;
   "b1  b2" is presumably the union -- confirm. *)
lemma (in vectorspace) fin_dim_subspace_sum:
assumes "subspace K W1 V"
assumes "subspace K W2 V"
assumes "vectorspace.fin_dim K (vs W1)" "vectorspace.fin_dim K (vs W2)"
shows "vectorspace.fin_dim K (vs (subspace_sum W1 W2))"
proof -
  obtain b1 where "finite b1" "b1  W1" "LinearCombinations.module.gen_set K (vs W1) b1"
    using assms vectorspace.fin_dim_def subspace_is_vs by force
  obtain b2 where "finite b2" "b2  W2" "LinearCombinations.module.gen_set K (vs W2) b2"
    using assms vectorspace.fin_dim_def subspace_is_vs by force
  have 1:"finite (b1  b2)" by (simp add: ‹finite b1 ‹finite b2)
  have 2:"b1  b2  subspace_sum W1 W2"
    by (metis (no_types, lifting) b1  W1 b2  W2 assms(1) assms(2)
    le_sup_iff subset_Un_eq vectorspace.in_sum_vs vectorspace.vsum_comm vectorspace_axioms)
  have 3:"LinearCombinations.module.gen_set K (vs (subspace_sum W1 W2)) (b1  b2)"
  proof (rule subset_antisym)
    (* Fact 0: the span computed inside the sum space equals the span
       computed in V, so the rest can be argued in V. *)
    have 0:"LinearCombinations.module.span K (vs (subspace_sum W1 W2)) (b1  b2) = span (b1  b2)"
      using span_li_not_depend(1)[OF b1  b2  subspace_sum W1 W2] sum_is_subspace[OF assms(1) assms(2)] by auto
    then show "LinearCombinations.module.span K (vs (subspace_sum W1 W2)) (b1  b2)  carrier (vs (subspace_sum W1 W2))"
      using b1  b2  subspace_sum W1 W2 span_is_subset sum_is_subspace[OF assms(1) assms(2)] by auto
    show "carrier (vs (subspace_sum W1 W2))  LinearCombinations.module.span K (vs (subspace_sum W1 W2)) (b1  b2)"
     unfolding 0
    proof
      fix x assume assumption:"x  carrier (vs (subspace_sum W1 W2))"
      then have "xsubspace_sum W1 W2" by auto
      (* Decompose x into a W1-part and a W2-part. *)
      then obtain x1 x2 where "x = x1 V x2" "x1W1" "x2W2"
        using imageE[OF x  subspace_sum W1 W2[unfolded submodule_sum_def]]
        by (metis (no_types, lifting) BNF_Def.Collect_case_prodD split_def)
      have "x1span b1"  "x2span b2"
        using ‹LinearCombinations.module.span K (vs W1) b1 = carrier (vs W1) b1  W1 x1  W1
              ‹LinearCombinations.module.span K (vs W2) b2 = carrier (vs W2) b2  W2 x2  W2
        assms(1) assms(2) span_li_not_depend(1) by auto
      (* Monotonicity of span, then closure of the span under addition. *)
      then have "x1span (b1  b2)" "x2span (b1  b2)" by (meson le_sup_iff subsetD span_is_monotone subsetI)+
      then show "x  span (b1  b2)" unfolding x = x1 V x2
        by (meson b1  b2  subspace_sum W1 W2 assms(1) assms(2) is_module submodule.subset
        subset_trans sum_is_submodule vectorspace.span_add1 vectorspace_axioms)
    qed
  qed
  show ?thesis using 1 2 3 vectorspace.fin_dim_def
    by (metis assms(1) assms(2) local.carrier_vs_is_self subspace_def sum_is_subspace vectorspace.subspace_is_vs)
qed

(* Subadditivity of matrix rank for two matrices A, B of the same
   dimensions (n rows, nc columns): relates rank (A + B) to rank A + rank B.
   NOTE(review): the relation symbol in the conclusion is not visible in
   this copy; per the section text it is "rank (A + B) <= rank A + rank B".
   Proof idea: with W1, W2 the column spans of A and B, the columns of A + B
   lie in subspace_sum W1 W2; then combine subspace_dim with dim_subadditive. *)
lemma (in vec_space) rank_subadditive:
assumes "A  carrier_mat n nc"
assumes "B  carrier_mat n nc"
shows "rank (A + B)  rank A + rank B"
proof -
  define W1 where "W1 = span (set (cols A))"
  define W2 where "W2 = span (set (cols B))"
  (* Each column of A + B is the sum of the corresponding columns of A and B,
     hence an element of the sum of the two column spans. *)
  have "set (cols (A + B))  subspace_sum W1 W2"
  proof
    fix x assume "x  set (cols (A + B))"
    obtain i where "x = col (A + B) i" "i < length (cols (A + B))"
      using x  set (cols (A + B)) nth_find_first cols_nth find_first_le by (metis cols_length)
    then have "x = col A i + col B i" using i < length (cols (A + B)) assms(1) assms(2) by auto
    have "col A i  span (set (cols A))" "col B i  span (set (cols B))"
      using i < length (cols (A + B)) assms(1) assms(2) in_set_conv_nth
      by (metis cols_dim cols_length cols_nth carrier_matD(1) carrier_matD(2) index_add_mat(3) span_mem)+
    then show "x  subspace_sum W1 W2"
      unfolding W1_def W2_def x = col A i + col B i submodule_sum_def by blast
  qed
  have "subspace class_ring (subspace_sum W1 W2) V"
    by (metis W1_def W2_def assms(1) assms(2) cols_dim carrier_matD(1) span_is_submodule subspace_def sum_is_submodule vec_vs)
  then have "span (set (cols (A + B)))  subspace_sum W1 W2"
    by (simp add: ‹set (cols (A + B))  subspace_sum W1 W2 span_is_subset)
  have "subspace class_ring (span (set (cols (A + B)))) V" by (metis assms(2) cols_dim add_carrier_mat carrier_matD(1) span_is_subspace)
  (* The column span of A + B is a subspace of the sum space itself. *)
  have subspace:"subspace class_ring (span (set (cols (A + B)))) (vs (subspace_sum W1 W2))"
    using nested_subspaces[OF ‹subspace class_ring (subspace_sum W1 W2) V› ‹subspace class_ring (span (set (cols (A + B)))) V›
    ‹span (set (cols (A + B)))  subspace_sum W1 W2] .
  (* Finite-dimensionality facts required by subspace_dim / dim_subadditive. *)
  have "vectorspace.fin_dim class_ring (vs W1)" "vectorspace.fin_dim class_ring (vs W2)"
       "subspace class_ring W1 V" "subspace class_ring W2 V"
    using span_is_subspace W1_def W2_def assms(1) assms(2) cols_dim carrier_matD fin_dim_span_cols by auto
  then have fin_dim: "vectorspace.fin_dim class_ring (vs (subspace_sum W1 W2))" using fin_dim_subspace_sum by auto
  have "vectorspace.fin_dim class_ring (span_vs (set (cols (A + B))))" using assms(2) add_carrier_mat vec_space.fin_dim_span_cols by blast
  (* Chain: rank (A + B) vs. dim of the sum space, then dim of the sum space
     vs. rank A + rank B. *)
  then have "rank (A + B)  vectorspace.dim class_ring (vs (subspace_sum W1 W2))" unfolding rank_def
    using vectorspace.subspace_dim[OF subspace_is_vs[OF ‹subspace class_ring (subspace_sum W1 W2) V›] subspace fin_dim] by auto
  also have "vectorspace.dim class_ring (vs (subspace_sum W1 W2))  rank A + rank B" unfolding rank_def
    using W1_def W2_def ‹subspace class_ring W1 V› ‹subspace class_ring W2 V› ‹vectorspace.fin_dim class_ring (vs W1)
    ‹vectorspace.fin_dim class_ring (vs W2) subspace_def vectorspace.dim_subadditive by blast
  finally show ?thesis by auto
qed

(* The span of the singleton containing only the zero vector of V is that
   same singleton.  Discharged by a single metis call over basic span facts
   (span_empty, span_is_subset, span_is_subset2, in_own_span, ...). *)
lemma (in vec_space) span_zero: "span {zero V} = {zero V}"
  by (metis (no_types, lifting) empty_subsetI in_own_span span_is_submodule span_is_subset
  span_is_subset2 subset_antisym vectorspace.span_empty vectorspace_axioms)

(* The vector space spanned by the empty set has dimension 0.
   Proof idea: the empty set is a basis of span_vs {}, and the dimension
   equals the cardinality of a basis (vectorspace.dim_basis). *)
lemma (in vec_space) dim_zero_vs: "vectorspace.dim class_ring (span_vs {}) = 0"
proof -
  have "vectorspace class_ring (span_vs {})" using field.field_axioms span_is_submodule submodule_is_module vectorspace_def by auto
  (* The empty set is a linearly independent set of vectors. *)
  have "{}  carrier_vec n  lin_indpt {}"
    by (metis (no_types) empty_subsetI fin_dim finite_basis_exists subset_li_is_li vec_vs vectorspace.basis_def)
  then have "vectorspace.basis class_ring (span_vs {}) {}" using vectorspace.basis_def
    by (simp add: ‹vectorspace class_ring (vs (span {})) span_is_submodule span_li_not_depend(1) span_li_not_depend(2) vectorspace.basis_def)
  then show ?thesis using ‹vectorspace class_ring (vs (span {})) vectorspace.dim_basis by fastforce
qed

(* The zero matrix with n rows and nc columns has rank 0.
   Proof idea: its set of columns is either empty or the singleton of the
   zero vector; in both cases the span is the singleton of the zero vector,
   whose dimension is 0 by dim_zero_vs (after rewriting with span_empty). *)
lemma (in vec_space) rank_0I: "rank (0m n nc) = 0"
proof -
  have "set (cols (0m n nc))  {0v n}"
    by (metis col_zero cols_length cols_nth in_set_conv_nth insertCI index_zero_mat(3) subsetI)
  (* Case split obtained from subset_singletonD (the empty case arises when
     there are no columns). *)
  have "set (cols (0m n nc::'a mat)) = {}  set (cols (0m n nc)) = {0v n::'a vec}"
    by (meson ‹set (cols (0m n nc))  {0v n} subset_singletonD)
  then have "span (set (cols (0m n nc))) = {0v n}"
    by (metis (no_types) span_empty span_zero vectorspace.span_empty vectorspace_axioms)
  then show ?thesis unfolding rank_def ‹span (set (cols (0m n nc))) = {0v n}
    using span_empty dim_zero_vs by simp
qed


(* A matrix whose entries factor as A $$ (r,c) = f r * g c has rank bounded
   by 1.
   NOTE(review): the relation symbol in the conclusion is not visible in
   this copy; the lemma name and dim_le1I indicate "rank A ≤ 1" -- confirm.
   Proof idea: every column of A is the scalar multiple (g c) of the single
   vector "vec n f", so the column space sits inside the span of that one
   vector, whose dimension is bounded via dim_le1I; subspace_dim then
   transfers the bound to the column space. *)
lemma (in vec_space) rank_le_1_product_entries:
fixes f g::"nat  'a"
assumes "A  carrier_mat n nc"
assumes "r c. r<dim_row A  c<dim_col A  A $$ (r,c) = f r * g c"
shows "rank A  1"
proof -
  have "set (cols A)  span {vec n f}"
  proof
    fix v assume "v  set (cols A)"
    then obtain c where "c < dim_col A" "v = col A c" by (metis cols_length cols_nth in_set_conv_nth)
    (* Column c equals g c times the vector built from f; shown entrywise. *)
    have "g c v vec n f = v"
    proof (rule eq_vecI)
      show "dim_vec (g c v Matrix.vec n f) = dim_vec v" using v = col A c assms(1) by auto
      fix r assume "r < dim_vec v"
      then have "r < dim_vec (Matrix.vec n f)" using ‹dim_vec (g c v Matrix.vec n f) = dim_vec v by auto
      then have "r < n" "r < dim_row A"using index_smult_vec(2) A  carrier_mat n nc by auto
      show "(g c v Matrix.vec n f) $ r = v $ r"
        unfolding v = col A c col_def index_smult_vec(1)[OF r < dim_vec (Matrix.vec n f)]
        index_vec[OF r < n] index_vec[OF r < dim_row A] by (simp add: c < dim_col A r < dim_row A assms(2))
    qed
    (* Spans are closed under scalar multiplication. *)
    then show "v  span {vec n f}" using submodule.smult_closed[OF span_is_submodule]
      using UNIV_I empty_subsetI insert_subset span_self dim_vec module_vec_simps(4) by auto
  qed
  have "vectorspace class_ring (vs (span {Matrix.vec n f}))" using span_is_subspace[THEN subspace_is_vs, of "{vec n f}"] by auto
  have "submodule class_ring (span {Matrix.vec n f}) V" by (simp add: span_is_submodule)
  (* The column space is a subspace of the space spanned by "vec n f". *)
  have "subspace class_ring(span (set (cols A))) (vs (span {Matrix.vec n f}))"
    using vectorspace.span_is_subspace[OF ‹vectorspace class_ring (vs (span {Matrix.vec n f})), of "set (cols A)", unfolded
    span_li_not_depend(1)[OF ‹set (cols A)  span {vec n f} ‹submodule class_ring (span {Matrix.vec n f}) V›]]
    ‹set (cols A)  span {vec n f} by auto
  have fin_dim:"vectorspace.fin_dim class_ring (vs (span {Matrix.vec n f}))"
       "vectorspace.fin_dim class_ring (vs (span {Matrix.vec n f})carrier := span (set (cols A)))"
    using fin_dim_span fin_dim_span_cols A  carrier_mat n nc by auto
  have "vectorspace.dim class_ring (vs (span {Matrix.vec n f}))  1"
    using vectorspace.dim_le1I[OF ‹vectorspace class_ring (vs (span {Matrix.vec n f}))]
    span_mem span_li_not_depend(1)[OF _ ‹submodule class_ring (span {Matrix.vec n f}) V›] by simp
  then show ?thesis unfolding rank_def using  "vectorspace.subspace_dim"[OF
    ‹vectorspace class_ring (vs (span {Matrix.vec n f})) ‹subspace class_ring (span (set (cols A))) (vs (span {Matrix.vec n f}))
    fin_dim(1) fin_dim(2)] by simp
qed

end

Theory DL_Missing_Sublist

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Missing Lemmas of Sublist›

theory DL_Missing_Sublist
imports Main
begin

(* If the set of valid positions of xs selected by I is the singleton {j},
   then nths xs I is the one-element list [xs!j].
   NOTE(review): the connective inside the set comprehension is not visible
   in this copy; presumably "i < length xs ∧ i ∈ I" -- confirm.
   Proof: the result has element set {xs!j} and length 1, which together
   determine the list. *)
lemma nths_only_one:
assumes "{i. i < length xs  iI} = {j}"
shows "nths xs I = [xs!j]"
proof -
  have "set (nths xs I) = {xs!j}"
    unfolding set_nths using subset_antisym assms by fastforce
  moreover have "length (nths xs I) = 1"
    unfolding length_nths assms by auto
  ultimately show ?thesis
    by (metis One_nat_def length_0_conv length_Suc_conv the_elem_eq the_elem_set)
qed

(* Selecting positions A from "replicate n x" yields another replicate list
   of x whose length is the number of selected positions below n.
   NOTE(review): the connectives inside the set comprehensions are not
   visible in this copy; presumably "i < n ∧ i ∈ A" -- confirm.
   Proof: induction on n, with a case split on whether position n is
   selected; the list is viewed as "replicate n x @ [x]" so that nths_append
   applies. *)
lemma nths_replicate:
"nths (replicate n x) A = (replicate (card {i. i < n  i  A}) x)"
proof (induction n)
  case 0
  then show ?case by simp
next
  case (Suc n)
  then show ?case
  proof (cases "nA")
    case True
    (* Fact 0 simplifies the contribution of the appended last element. *)
    then have 0:"(if 0  {j. j + length (replicate n x)  A} then [x] else []) = [x]" by simp
    have "{i. i < Suc n  i  A} = insert n {i. i < n  i  A}" using True by auto
    have "Suc (card {i. i < n  i  A}) = card {i. i < Suc n  i  A}"
      unfolding {i. i < Suc n  i  A} = insert n {i. i < n  i  A}
      using finite_Collect_conjI[THEN card_insert_if] finite_Collect_less_nat
       less_irrefl_nat mem_Collect_eq by simp
    then show ?thesis unfolding replicate_Suc replicate_append_same[symmetric] nths_append Suc nths_singleton 0
      unfolding replicate_append_same replicate_Suc[symmetric] by simp
  next
    case False
    (* The last element is dropped, and the counted set does not change. *)
    then have 0:"(if 0  {j. j + length (replicate n x)  A} then [x] else []) = []" by simp
    have "{i. i < Suc n  i  A} = {i. i < n  i  A}" using False using le_less less_Suc_eq_le by auto
    then show ?thesis unfolding replicate_Suc replicate_append_same[symmetric] nths_append Suc nths_singleton 0
      by simp
  qed
qed

(* For a list of even length, the sublist at even positions and the sublist
   at odd positions have the same length.
   Proof: induction on "length xs div 2"; in the step the list is split as
   "take 2 xs @ drop 2 xs", the first part contributes exactly one element
   to each side and the induction hypothesis handles the rest. *)
lemma length_nths_even:
assumes "even (length xs)"
shows "length (nths xs (Collect even)) = length (nths xs (Collect odd))"
using assms proof (induction "length xs div 2" arbitrary:xs)
  case 0
  then have "length xs = 0"
    by (auto elim: evenE)
  then show ?case by simp
next
  case (Suc l xs)
  (* Induction hypothesis instantiated for the list without its first two
     elements. *)
  then have length_drop2: "length (nths (drop 2 xs) (Collect even)) = length (nths (drop 2 xs) {a. odd a})" by simp

  have "length (take 2 xs) = 2" using Suc.hyps(2) by auto
  (* Shifting indices by 2 preserves parity, so the index sets used by
     nths_append for the second part are unchanged. *)
  then have plus_odd: "{j. j + length (take 2 xs)  Collect odd} = Collect odd" and
            plus_even: "{j. j + length (take 2 xs)  Collect even} = Collect even" by simp_all
  have nths_take2: "nths (take 2 xs) (Collect even) = [take 2 xs ! 0]" "nths (take 2 xs) (Collect odd) = [take 2 xs ! 1]"
    using ‹length (take 2 xs) = 2 less_2_cases nths_only_one[of "take 2 xs" "Collect even" 0]
    nths_only_one[of "take 2 xs" "Collect odd" 1]
    by fastforce+
  then have "length (nths (take 2 xs @ drop 2 xs) (Collect even))
           = length (nths (take 2 xs @ drop 2 xs) {a. odd a})"
    unfolding nths_append length_append plus_odd plus_even nths_take2 length_drop2
    by auto
  then show ?case using append_take_drop_id[of 2 xs] by simp
qed

(* Mapping a function over a list commutes with selecting positions:
   selecting the positions A from "map f xs" gives the same list as mapping f
   over the selection from xs.  Proved by induction on xs (for arbitrary A);
   both cases are closed by simplification, the Cons case needing nths_Cons. *)
lemma nths_map:
"nths (map f xs) A = map f (nths xs A)"
  by (induction xs arbitrary: A) (simp_all add: nths_Cons)


section "Pick"

(* pick S n enumerates elements of a set S of naturals in increasing order:
   pick S 0 is the least a satisfying the first condition, and
   pick S (Suc n) is the least a satisfying the condition that is also
   greater than pick S n.
   NOTE(review): the membership/conjunction symbols are not visible in this
   copy; presumably the conditions are "a ∈ S" and "a ∈ S ∧ a > pick S n"
   -- confirm.  Since LEAST is used, the value is only meaningful when such
   an element exists (see the lemmas that assume "infinite S" or
   "n < card S"). *)
fun pick :: "nat set  nat  nat" where
"pick S 0 = (LEAST a. aS)" |
"pick S (Suc n) = (LEAST a. aS  a > pick S n)"

(* For an infinite set S, pick S n is related to S as stated in the
   conclusion.
   NOTE(review): the relation symbol is not visible in this copy; the lemma
   name and the use of LeastI indicate membership "pick S n ∈ S" -- confirm.
   Proof: case distinction on n; in the successor case an element of S above
   pick S n' exists because S is infinite, so LeastI applies. *)
lemma pick_in_set_inf:
assumes "infinite S"
shows "pick S n  S"
proof (cases n)
  show "n = 0  pick S n  S"
    unfolding pick.simps using ‹infinite S LeastI pick.simps(1) by (metis Collect_mem_eq not_finite_existsD)
next
  fix n' assume "n = Suc n'"
  (* Witness for the LEAST in pick.simps(2). *)
  obtain a where "aS  a > pick S n'" using assms by (metis bounded_nat_set_is_finite less_Suc_eq nat_neq_iff)
  show "pick S n  S" unfolding n = Suc n' pick.simps(2)
    using LeastI[of "λa. a  S  pick S n' < a" a, OF aS  a > pick S n'] by blast
qed

(* For an infinite set S, pick S is strictly increasing in its index:
   from m < n one gets pick S m < pick S n.
   NOTE(review): the implication arrow is not visible in this copy -- confirm.
   Proof: induction on n; the step first shows pick S n < pick S (Suc n)
   via LeastI and then uses transitivity with the induction hypothesis. *)
lemma pick_mono_inf:
assumes "infinite S"
shows "m < n  pick S m < pick S n"
using assms proof (induction n)
  case 0
  then show ?case by auto
next
  case (Suc n)
  (* Some element of S lies above pick S n because S is infinite. *)
  then obtain a where "a  S  pick S n < a" by (metis bounded_nat_set_is_finite less_Suc_eq nat_neq_iff)
  then have "pick S n < pick S (Suc n)" unfolding pick.simps
    using LeastI[of "λa. a  S  pick S n < a" a, OF aS  a > pick S n] by simp
  then show ?case using Suc.IH Suc.prems(1) assms dual_order.strict_trans less_Suc_eq by auto
qed

(* For an infinite set S, relates equality of indices x = y to equality of
   the picked elements pick S x = pick S y.
   NOTE(review): the connective between the two equations is not visible in
   this copy; the lemma name suggests an equivalence -- confirm.
   Proof: follows from strict monotonicity (pick_mono_inf). *)
lemma pick_eq_iff_inf:
assumes "infinite S"
shows "x = y  pick S x = pick S y"
  by (metis assms nat_neq_iff pick_mono_inf)

(* For an infinite set S: relates the number of elements of S below i to the
   index n, given a comparison between pick S n and i.
   NOTE(review): the relation symbols are not visible in this copy; the
   subset step in the proof ({a ∈ S. a < i} inside
   {a ∈ S. a < pick S (Suc n)}) and the companion lemma card_le_pick suggest
   "i ≤ pick S n implies card {a ∈ S. a < i} ≤ n" -- confirm.
   Proof: induction on n (i arbitrary). *)
lemma card_le_pick_inf:
assumes "infinite S"
and "pick S n  i"
shows "card {aS. a < i}  n"
using assms proof (induction n arbitrary:i)
  case 0
  then show ?case unfolding pick.simps using not_less_Least
    by (metis (no_types, lifting) Collect_empty_eq card_0_eq card_ge_0_finite dual_order.strict_trans1 leI le_0_eq)
next
  case (Suc n)
  then show ?case
  proof -
    (* Induction hypothesis instantiated with i := pick S n. *)
    have "card {a  S. a < pick S n}  n" using Suc by blast
    have "{a  S. a < i}  {a  S. a < pick S (Suc n)}" using Suc.prems(2) by auto
    (* The elements below pick S (Suc n) are those below pick S n together
       with pick S n itself. *)
    have "{a  S. a < pick S (Suc n)} = {a  S. a < pick S n}  {pick S n}"
      apply (rule subset_antisym; rule subsetI)
      using not_less_Least UnCI mem_Collect_eq nat_neq_iff singleton_conv
       pick_mono_inf[OF Suc.prems(1), of n "Suc n"] pick_in_set_inf[OF Suc.prems(1), of n] by fastforce+
    then have "card {a  S. a < i}  card {a  S. a < pick S n} + card {pick S n}"
      using card_Un_disjoint  card_mono[OF _ {a  S. a < i}  {a  S. a < pick S (Suc n)}] by simp
    then show ?thesis using ‹card {a  S. a < pick S n}  n  by auto
  qed
qed

(* For an infinite set S, exactly n elements of S are smaller than pick S n.
   Proof: induction on n; in the step, the elements below pick S (Suc n) are
   those below pick S n together with pick S n itself, so the count grows by
   one (card_Un_disjoint). *)
lemma card_pick_inf:
assumes "infinite S"
shows "card {aS. a < pick S n} = n"
using assms proof (induction n)
  case 0
  then show ?case unfolding pick.simps using not_less_Least by auto
next
  case (Suc n)
  then show "card {aS. a < pick S (Suc n)} = Suc n"
  proof -
    have "{a  S. a < pick S (Suc n)} = {a  S. a < pick S n}  {pick S n}"
      apply (rule subset_antisym; rule subsetI)
      using not_less_Least UnCI mem_Collect_eq nat_neq_iff singleton_conv
       pick_mono_inf[OF Suc.prems, of n "Suc n"] pick_in_set_inf[OF Suc.prems, of n] by fastforce+
    then have "card {a  S. a < pick S (Suc n)} = card {a  S. a < pick S n} + card {pick S n}"  using card_Un_disjoint by auto
    then show ?thesis by (metis One_nat_def Suc_eq_plus1 Suc card.empty card_insert_if empty_iff finite.emptyI)
  qed
qed

(* Three facts about pick S n under the assumption n < card S (the
   counterpart of the "_inf" lemmas for indices below the cardinality),
   proved simultaneously by induction on n:
     pick_in_set_le : relation of pick S n to S,
     card_pick_le   : exactly n elements of S are smaller than pick S n,
     pick_mono_le   : pick S m < pick S n for m < n.
   NOTE(review): membership and implication symbols are not visible in this
   copy; presumably "pick S n ∈ S" and "m < n ⟹ ..." -- confirm. *)
lemma
assumes "n < card S"
shows
  pick_in_set_le:"pick S n  S" and
  card_pick_le: "card {aS. a < pick S n} = n" and
  pick_mono_le: "m < n  pick S m < pick S n"
using assms proof (induction n)
  (* Base case: S has at least one element, so the LEAST in pick S 0 exists. *)
  assume "0 < card S"
  then obtain x where "xS" by fastforce
  then show "pick S 0  S" unfolding pick.simps by (meson LeastI)
  then show "card {a  S. a < pick S 0} = 0" using not_less_Least by auto
  show "m < 0   pick S m < pick S 0" by auto
next
  fix n
  assume "n < card S  pick S n  S"
    and "n < card S  card {a  S. a < pick S n} = n"
    and "Suc n < card S"
    and "m < n  n < card S  pick S m < pick S n"
  then have "card {a  S. a < pick S n} = n" "pick S n  S" by linarith+
  (* Key step: S contains an element strictly above pick S n, obtained by a
     counting argument from "Suc n < card S". *)
  have "card {a  S. a > pick S n} > 0"
  proof -
    have "S = {a  S. a < pick S n}  {a  S. a  pick S n}" by fastforce
    then have "card {a  S. a  pick S n} > 1"
      using ‹Suc n < card S ‹card {a  S. a < pick S n} = n
      card_Un_le[of "{a  S. a < pick S n}" "{a  S. pick S n  a}"] by force
    then have 0:"{a  S. a  pick S n}  {pick S n}  {a  S. a > pick S n}" by auto
    have 1:"finite ({pick S n}  {a  S. pick S n < a})"
      unfolding finite_Un using Collect_mem_eq assms card.infinite conjI by force
    have "1 < card {pick S n} + card {a  S. pick S n < a}"
      using card_mono[OF 1 0] card_Un_le[of "{pick S n}" "{a  S. a > pick S n}"]  ‹card {a  S. a  pick S n} > 1
      by linarith
    then show ?thesis by simp
  qed
  (* Therefore the LEAST defining pick S (Suc n) exists. *)
  then show "pick S (Suc n)  S" unfolding pick.simps
    by (metis (no_types, lifting) Collect_empty_eq LeastI card_0_eq card.infinite less_numeral_extra(3))
  have "pick S (Suc n) > pick S n"
    by (metis (no_types, lifting) pick.simps(2) ‹card {a  S. a > pick S n} > 0 Collect_empty_eq LeastI card_0_eq card.infinite less_numeral_extra(3))
  then show "m < Suc n  pick S m < pick S (Suc n)"
    using m < n  n < card S  pick S m < pick S n
    using ‹Suc n < card S dual_order.strict_trans less_Suc_eq by auto
  then show "card {aS. a < pick S (Suc n)} = Suc n"
  proof -
    (* Same decomposition as in card_pick_inf: one new element, pick S n. *)
    have "{a  S. a < pick S (Suc n)} = {a  S. a < pick S n}  {pick S n}"
      apply (rule subset_antisym; rule subsetI)
      using pick.simps not_less_Least ‹pick S (Suc n) > pick S n ‹pick S n  S by fastforce+
    then have "card {a  S. a < pick S (Suc n)} = card {a  S. a < pick S n} + card {pick S n}"  using card_Un_disjoint by auto
    then show ?thesis by (metis One_nat_def Suc_eq_plus1 ‹card {a  S. a < pick S n} = n card.empty card_insert_if empty_iff finite.emptyI)
  qed
qed

(* Counterpart of card_le_pick_inf under the assumption n < card S: relates
   the number of elements of S below i to n, given a comparison between
   pick S n and i.
   NOTE(review): the relation symbols are not visible in this copy;
   presumably "i ≤ pick S n implies card {a ∈ S. a < i} ≤ n" -- confirm.
   Proof: induction on n (i arbitrary), using pick_mono_le and
   pick_in_set_le in place of the "_inf" variants. *)
lemma card_le_pick_le:
assumes "n < card S"
and "pick S n  i"
shows "card {aS. a < i}  n"
using assms proof (induction n arbitrary:i)
  case 0
  then show ?case unfolding pick.simps using not_less_Least
    by (metis (no_types, lifting) Collect_empty_eq card_0_eq card_ge_0_finite dual_order.strict_trans1 leI le_0_eq)
next
  case (Suc n)
  have "card {a  S. a < pick S n}  n" using Suc by (simp add: less_eq_Suc_le nat_less_le)
  have "{a  S. a < i}  {a  S. a < pick S (Suc n)}" using Suc.prems(2) by auto
  (* Elements below pick S (Suc n): those below pick S n plus pick S n. *)
  have "{a  S. a < pick S (Suc n)} = {a  S. a < pick S n}  {pick S n}"
    apply (rule subset_antisym; rule subsetI)
    using pick.simps not_less_Least  pick_mono_le[OF Suc.prems(1), of n, OF lessI] pick_in_set_le[of n S] Suc by fastforce+
  then have "card {a  S. a < i}  card {a  S. a < pick S n} + card {pick S n}"
    using card_Un_disjoint  card_mono[OF _ {a  S. a < i}  {a  S. a < pick S (Suc n)}] by simp
  then show ?case using ‹card {a  S. a < pick S n}  n  by auto
qed

(* Unified versions of the pick lemmas, valid when either n < card S or S is
   infinite.  Each fact is obtained directly from its "_le" and "_inf"
   variants.
   NOTE(review): the connective in the assumption and the relation symbols
   in the conclusions are not visible in this copy; presumably
   "n < card S ∨ infinite S", "pick S n ∈ S" and
   "i ≤ pick S n ==> card {a ∈ S. a < i} ≤ n" -- confirm. *)
lemma
assumes "n < card S  infinite S"
shows
  pick_in_set:"pick S n  S" and
  card_le_pick: "i  pick S n ==> card {aS. a < i}  n" and
  card_pick: "card {aS. a < pick S n} = n" and
  pick_mono: "m < n  pick S m < pick S n"
    using assms pick_in_set_inf pick_in_set_le card_pick_inf card_pick_le card_le_pick_le card_le_pick_inf
    pick_mono_inf pick_mono_le by auto

(* Picking at the index "number of elements of I below i" yields a LEAST
   element of I characterised relative to i.
   NOTE(review): the symbols inside the LEAST predicate are not visible in
   this copy; presumably "LEAST a. a ∈ I ∧ a ≥ i" -- confirm.
   Proof: induction on i, with a case split on whether i belongs to I. *)
lemma pick_card:
"pick I (card {aI. a < i}) = (LEAST a. aI  a  i)"
proof (induction i)
  case 0
  then show ?case by (simp add: pick_in_set_le)
next
  case (Suc i)
  then show ?case
  proof (cases "iI")
    case True
    (* If i is in I, the induction hypothesis pins the pick to i itself, and
       the counted set for Suc i gains exactly the element i. *)
    then have 1:"pick I (card {aI. a < i}) = i" by (metis (mono_tags, lifting) Least_equality Suc.IH order_refl)
    have "{a  I. a < Suc i} = {a  I. a < i}  {i}" using True by auto
    then have 2:"card {a  I. a < Suc i} = Suc (card {a  I. a < i})" by auto
    then show ?thesis unfolding 2 pick.simps 1 using Suc_le_eq by auto
  next
    case False
    (* If i is not in I, neither the counted set nor the LEAST predicate
       changes when passing from i to Suc i. *)
    then have 1:"{a  I. a < Suc i} = {a  I. a < i}" using Collect_cong less_Suc_eq by auto
    have 2:"a. (a  I  Suc i  a) = (a  I  i  a)" using False Suc_leD le_less_Suc_eq not_le by blast
    then show ?thesis unfolding 1 2 using Suc.IH by blast
  qed
qed

(* Special case of pick_card for an element i of I: picking at the index
   "number of elements of I below i" returns i itself.
   NOTE(review): membership/implication symbols are not visible in this
   copy; presumably "i ∈ I ⟹ ..." -- confirm.
   Proof: rewrite with pick_card, then i is the least element of the
   predicate (Least_equality). *)
lemma pick_card_in_set: "iI  pick I (card {aI. a < i}) = i"
  unfolding pick_card using Least_equality order_refl by (metis (no_types, lifting))

section "Sublist"

(* Locating an element inside a sublist: if j is a valid position of xs that
   is selected by J, then xs!j appears in nths xs J at the index given by
   the number of selected positions below j.
   NOTE(review): membership/conjunction symbols are not visible in this
   copy; presumably "j ∈ J" and "j0 < j ∧ j0 ∈ J" -- confirm.
   Proof: reverse induction on xs (rev_induct), distinguishing whether j
   refers to the prefix xs or to the appended last element. *)
lemma nth_nths_card:
assumes "j<length xs"
and "jJ"
shows "nths xs J ! card {j0. j0 < j  j0  J} = xs!j"
using assms proof (induction xs rule:rev_induct)
  case Nil
  then show ?case using gr_implies_not0 list.size(3) by auto
next
  case (snoc x xs)
  then show ?case
  proof (cases "j < length xs")
    case True
    (* The index stays within the sublist of the prefix, so nth_append
       selects the left part and the induction hypothesis applies. *)
    have "{j0. j0 < j  j0  J}  {i. i < length xs  i  J}"
      using True snoc.prems(2) by auto
    then have "card {j0. j0 < j  j0  J} < length (nths xs J)" unfolding length_nths
      using psubset_card_mono[of "{i. i < length xs  i  J}"] by simp
    then show ?thesis unfolding nths_append nth_append by (simp add: True snoc.IH snoc.prems(2))
  next
    case False
    (* Otherwise j is the position of the appended element x. *)
    then have "length xs = j"
      using length_append_singleton less_antisym snoc.prems(1) by auto
    then show ?thesis unfolding nths_append nth_append length_nths ‹length xs = j
      by (simp add: snoc.prems(2))
  qed
qed

(* Restricting the set I to its elements below a bound m does not change
   pick at index i, as long as i is smaller than the number of such elements.
   NOTE(review): the connectives inside the set comprehensions are not
   visible in this copy; presumably "a < m ∧ a ∈ I" -- confirm.
   Proof: induction on i; in each case the LEAST element of the restricted
   predicate (abbreviated ?L) is shown to also be the LEAST element of the
   unrestricted one. *)
lemma pick_reduce_set:
assumes "i<card {a. a<m  aI}"
shows "pick I i = pick {a. a < m  a  I} i"
using assms proof (induction i)
  let ?L = "LEAST a. a  {a. a < m  a  I}"
  case 0
  (* The restricted set is non-empty because its cardinality is positive. *)
  then have "{a. a < m  a  I}  {}" using card.empty less_numeral_extra(3) by fastforce
  then have "?L  I" "?L < m" by (metis (mono_tags, lifting) Collect_empty_eq LeastI mem_Collect_eq)+
  have "x. x  {a. a < m  a  I}  ?L  x" by (simp add: Least_le)
  have "x. x  I  ?L  x"
    by (metis (mono_tags) ?L < m x. x  {a. a < m  a  I}  ?L  x dual_order.strict_trans2 le_cases mem_Collect_eq)
  then show ?case unfolding pick.simps using Least_equality[of "λx. xI", OF ?L  I] by blast
next
  case (Suc i)
  let ?L = "LEAST x. x  {a. a < m  a  I}  pick I i < x"
  (* Fact 0: the induction hypothesis, oriented for rewriting. *)
  have 0:"pick {a. a < m  a  I} i = pick I i" using Suc_lessD Suc by linarith
  then have "?L  {a. a < m  a  I}" "pick I i < ?L"
    using LeastI[of "λa. a  {a. a < m  a  I}  pick I i < a"] using Suc.prems pick_in_set_le pick_mono_le by fastforce+
  then have "?L  I" by blast
  show ?case unfolding pick.simps 0 using Least_equality[of "λa. a  I  pick I i < a" ?L]
    by (metis (no_types, lifting) Least_le ?L  {a. a < m  a  I} ‹pick I i < ?L mem_Collect_eq not_le not_less_iff_gr_or_eq order.trans)
qed

(* Indexing into `nths xs I`: for i below the number of valid indices of xs that
   lie in I, the i-th entry of `nths xs I` is the entry of xs at index
   `pick I i`.
   Proof: apply nth_nths_card at j = pick (restricted index set) i, using
   card_pick_le to identify the count of smaller members with i, and finally
   remove the restriction with pick_reduce_set. *)
lemma nth_nths:
assumes "i<card {i. i<length xs ∧ i∈I}"
shows "nths xs I ! i = xs ! pick I i"
proof -
  have "{a ∈ {i. i < length xs ∧ i ∈ I}. a < pick {i. i < length xs ∧ i ∈ I} i}
        = {a.  a < pick {i. i < length xs ∧ i ∈ I} i ∧ a ∈ I}"
    using assms pick_in_set by fastforce
  then have "card {a. a < pick {i. i < length xs ∧ i ∈ I} i ∧ a ∈ I} = i"
    using card_pick_le[OF assms] by simp
  then have "nths xs I ! i = xs ! pick {i. i < length xs ∧ i ∈ I} i"
    using nth_nths_card[where j = "pick {i. i < length xs ∧ i ∈ I} i", of xs I]
    assms pick_in_set pick_in_set by auto
  then show ?thesis using pick_reduce_set using assms by auto
qed

(* Picking from the universal set is the identity: pick UNIV j = j.
   Proof: induction on j; the base case is closed by simp, the successor case by
   metis from the second defining equation of pick (pick.simps(2)) together with
   standard facts about LEAST (LeastI, not_less_Least). *)
lemma pick_UNIV: "pick UNIV j = j"
by (induction j, simp, metis (no_types, lifting) LeastI pick.simps(2)  Suc_mono UNIV_I less_Suc_eq not_less_Least)

(* Upper bound for pick: if n is smaller than the number of elements of S below
   i, then the n-th picked element of S is itself below i.
   Proof: by contradiction, instantiating card_le_pick_le for the restricted set
   and rewriting the pick on the restricted set back to S via pick_reduce_set;
   fact 0 simplifies the doubly restricted set comprehension that arises. *)
lemma pick_le:
assumes "n < card {a. a < i ∧ a ∈ S}"
shows "pick S n < i"
proof -
  have 0:"{a ∈ {a. a < i ∧ a ∈ S}. a < i} = {a. a < i ∧ a ∈ S}" by blast
  show ?thesis apply (rule ccontr)
    using card_le_pick_le[OF assms, unfolded pick_reduce_set[OF assms, symmetric], of i, unfolded 0]
    assms not_less not_le by blast
qed

(* In a commutative multiplicative monoid, the product over a mapped list splits
   into the product over the entries at indices in A times the product over the
   entries at indices in the complement -A.
   Proof: reverse induction on xs; for the appended element we split on whether
   its index (length xs) lies in A and reorder factors with associativity and
   commutativity. *)
lemma prod_list_complementary_nthss:
fixes f ::"'a ⇒ 'b::comm_monoid_mult"
shows "prod_list (map f xs) = prod_list (map f (nths xs A)) *  prod_list (map f (nths xs (-A)))"
proof (induction xs rule:rev_induct)
  case Nil
  then show ?case by simp
next
  case (snoc x xs)
  show ?case unfolding map_append "prod_list.append" nths_append nths_singleton snoc
    by (cases "(length xs)∈A"; simp;metis mult.assoc mult.commute)
qed

(* nths commutes with zip: selecting the indices I from a zipped list is the
   same as zipping the two individually selected lists.
   Proof: list extensionality (nth_equalityI).
   - Lengths: split on which list is shorter; the kept-index set of the shorter
     list is contained in that of the longer one, so its cardinality is the
     minimum.
   - Entries: rewrite both sides with nth_nths and nth_zip; pick_le provides the
     index bounds that nth_zip needs on the original lists. *)
lemma nths_zip: "nths (zip xs ys) I = zip (nths xs I) (nths ys I)"
proof (rule nth_equalityI)
  show "length (nths (zip xs ys) I) = length (zip (nths xs I) (nths ys I))"
  proof (cases "length xs ≤ length ys")
    case True
    then have "{i. i < length xs ∧ i ∈ I} ⊆ {i. i < length ys ∧ i ∈ I}" by (simp add: Collect_mono less_le_trans)
    then have "card {i. i < length xs ∧ i ∈ I} ≤ card {i. i < length ys ∧ i ∈ I}"
      by (metis (mono_tags, lifting) card_mono finite_nat_set_iff_bounded mem_Collect_eq)
    then show ?thesis unfolding length_nths length_zip using True using min_def by linarith
  next
    case False
    then have "{i. i < length ys ∧ i ∈ I} ⊆ {i. i < length xs ∧ i ∈ I}" by (simp add: Collect_mono less_le_trans)
    then have "card {i. i < length ys ∧ i ∈ I} ≤ card {i. i < length xs ∧ i ∈ I}"
      by (metis (mono_tags, lifting) card_mono finite_nat_set_iff_bounded mem_Collect_eq)
    then show ?thesis unfolding length_nths length_zip using False using min_def by linarith
  qed
  show "nths (zip xs ys) I ! i = zip (nths xs I) (nths ys I) ! i" if "i < length (nths (zip xs ys) I)" for i
  proof -
   have "i < length (nths xs I)" "i < length (nths ys I)"
     using that by (simp_all add: ‹length (nths (zip xs ys) I) = length (zip (nths xs I) (nths ys I))›)
   show "nths (zip xs ys) I ! i = zip (nths xs I) (nths ys I) ! i"
     unfolding nth_nths[OF ‹i < length (nths (zip xs ys) I)›[unfolded length_nths]]
     unfolding nth_zip[OF ‹i < length (nths xs I)› ‹i < length (nths ys I)›]
     unfolding nth_zip[OF pick_le[OF ‹i < length (nths xs I)›[unfolded length_nths]]
                          pick_le[OF ‹i < length (nths ys I)›[unfolded length_nths]]]
     by (metis (full_types) ‹i < length (nths xs I)› ‹i < length (nths ys I)› length_nths nth_nths)
  qed
qed

section "weave"
    
(* weave A xs ys interleaves two lists into one list of length
   length xs + length ys.  Position i of the result is taken from xs when i ∈ A
   and from ys otherwise; the source index is the number of elements of A
   (respectively of the complement -A) that are smaller than i. *)
definition weave :: "nat set ⇒ 'a list ⇒ 'a list ⇒ 'a list" where
"weave A xs ys = map (λi. if i∈A then xs!(card {a∈A. a<i}) else ys!(card {a∈-A. a<i})) [0..<length xs + length ys]"

(* The woven list has as many entries as both input lists together; this is
   immediate from the definition, which maps over [0..<length xs + length ys]. *)
lemma length_weave:
  "length (weave A xs ys) = length xs + length ys"
  by (simp add: weave_def)

(* Pointwise characterisation of weave: for a valid index i, entry i of
   weave A xs ys comes from xs (if i ∈ A) or from ys (otherwise), at the index
   given by the number of smaller members of A respectively of -A.
   Proof: unfold the definition and use nth_map together with the fact that
   entry i of [0..<n] is i. *)
lemma nth_weave:
assumes "i < length (weave A xs ys)"
shows "weave A xs ys ! i = (if i∈A then xs!(card {a∈A. a<i}) else ys!(card {a∈-A. a<i}))"
proof -
  have "i < length xs + length ys" using length_weave using assms by metis
  then have "i < length [0..<length xs + length ys]" by auto
  then have "[0..<length xs + length ys] ! i = i"
    by (metis ‹i < length xs + length ys› add.left_neutral nth_upt)
  then show ?thesis
    unfolding weave_def nth_map[OF ‹i < length [0..<length xs + length ys]›] by presburger
qed

(* Appending to the first list: if the next free position
   (length xs + length ys) belongs to A, and xs has exactly as many entries as
   there are members of A below that position, then weaving xs @ [x] with ys
   just appends x to weave A xs ys.
   Proof: list extensionality.  For the last index the new entry is x.  For any
   earlier index i ∈ A, the number of members of A below i is strictly smaller
   than length xs, so looking it up in xs @ [x] is the same as looking it up in
   xs. *)
lemma weave_append1:
assumes "length xs + length ys ∈ A"
assumes "length xs = card {a∈A. a < length xs + length ys}"
shows "weave A (xs @ [x]) ys = weave A xs ys @ [x]"
proof (rule nth_equalityI)
  show "length (weave A (xs @ [x]) ys) = length (weave A xs ys @ [x])"
    unfolding weave_def length_map by simp
  show "weave A (xs @ [x]) ys ! i = (weave A xs ys @ [x]) ! i"
    if "i < length (weave A (xs @ [x]) ys)" for i
  proof -
    show "weave A (xs @ [x]) ys ! i = (weave A xs ys @ [x]) ! i"
    proof (cases "i = length xs + length ys")
      case True
      then have "(weave A xs ys @ [x]) ! i = x" using length_weave by (metis nth_append_length)
      have "card {a ∈ A. a < i} = length xs" using assms(2) True by auto
      then show ?thesis unfolding nth_weave[OF ‹i < length (weave A (xs @ [x]) ys)›]
        ‹(weave A xs ys @ [x]) ! i = x› using True assms(1) by simp
    next
      case False
      have "i < length (weave A xs ys)" using ‹i < length (weave A (xs @ [x]) ys)›
        ‹length (weave A (xs @ [x]) ys) = length (weave A xs ys @ [x])› length_append_singleton
        length_weave less_antisym False by fastforce
      then have "(weave A xs ys @ [x]) ! i = (weave A xs ys) ! i" by (simp add: nth_append)
      {
        assume "i∈A"
        have  "i<length xs + length ys" by (metis ‹i < length (weave A xs ys)› length_weave)
        then have "{a ∈ A. a < i} ⊂ {a∈A. a < length xs + length ys}"
          using assms(1) ‹i<length xs + length ys› ‹i∈A› by auto
        then have "card {a ∈ A. a < i} < card {a∈A. a < length xs + length ys}"
          using psubset_card_mono[of "{a∈A. a < length xs + length ys}" "{a ∈ A. a < i}"]  by simp
        then have "(xs @ [x]) ! card {a ∈ A. a < i} = xs ! card {a ∈ A. a < i}"
        by (metis (no_types, lifting)  assms(2) nth_append)
      }
      then show ?thesis unfolding nth_weave[OF ‹i < length (weave A (xs @ [x]) ys)›]
        ‹(weave A xs ys @ [x]) ! i = (weave A xs ys) ! i› nth_weave[OF ‹i < length (weave A xs ys)›]
        by simp
    qed
  qed
qed

(* Appending to the second list (mirror image of weave_append1): if the next
   free position (length xs + length ys) does not belong to A, and ys has
   exactly as many entries as there are members of -A below that position, then
   weaving xs with ys @ [y] just appends y to weave A xs ys.
   Proof: list extensionality.  For the last index the new entry is y.  For any
   earlier index i ∉ A, the number of members of -A below i is strictly smaller
   than length ys, so looking it up in ys @ [y] is the same as looking it up in
   ys. *)
lemma weave_append2:
assumes "length xs + length ys ∉ A"
assumes "length ys = card {a∈-A. a < length xs + length ys}"
shows "weave A xs (ys @ [y]) = weave A xs ys @ [y]"
proof (rule nth_equalityI)
  show "length (weave A xs (ys @ [y])) = length (weave A xs ys @ [y])"
    unfolding weave_def length_map by simp
  show "weave A xs (ys @ [y]) ! i = (weave A xs ys @ [y]) ! i" if "i < length (weave A xs (ys @ [y]))" for i
  proof -
    show "weave A xs (ys @ [y]) ! i = (weave A xs ys @ [y]) ! i"
    proof (cases "i = length xs + length ys")
      case True
      then have "(weave A xs ys @ [y]) ! i = y" using length_weave by (metis nth_append_length)
      have "card {a ∈ -A. a < i} = length ys" using assms(2) True by auto
      then show ?thesis unfolding nth_weave[OF ‹i < length (weave A xs (ys @ [y]))›]
        ‹(weave A xs ys @ [y]) ! i = y› using True assms(1) by simp
    next
      case False
      have "i < length (weave A xs ys)" using ‹i < length (weave A xs (ys @ [y]))›
        ‹length (weave A xs (ys @ [y])) = length (weave A xs ys @ [y])› length_append_singleton
        length_weave less_antisym False by fastforce
      then have "(weave A xs ys @ [y]) ! i = (weave A xs ys) ! i" by (simp add: nth_append)
      {
        assume "i∉A"
        have  "i<length xs + length ys" by (metis ‹i < length (weave A xs ys)› length_weave)
        then have "{a ∈ -A. a < i} ⊂ {a∈-A. a < length xs + length ys}"
          using assms(1) ‹i<length xs + length ys› ‹i∉A› by auto
        then have "card {a ∈ -A. a < i} < card {a∈-A. a < length xs + length ys}"
          using psubset_card_mono[of "{a∈-A. a < length xs + length ys}" "{a ∈ -A. a < i}"]  by simp
        then have "(ys @ [y]) ! card {a ∈ -A. a < i} = ys ! card {a ∈ -A. a < i}"
        by (metis (no_types, lifting)  assms(2) nth_append)
      }
      then show ?thesis unfolding nth_weave[OF ‹i < length (weave A xs (ys @ [y]))›]
        ‹(weave A xs ys @ [y]) ! i = (weave A xs ys) ! i› nth_weave[OF ‹i < length (weave A xs ys)›]
        by simp
    qed
  qed
qed

(* If n ∈ A is a valid index of xs, then entry n of xs appears in `nths xs A` at
   the position given by the number of members of A below n.
   Proof: reverse induction on xs.  Either n is the index of the appended
   element (then the position equals length (nths xs A)), or n indexes the
   prefix, where the induction hypothesis applies and the position stays inside
   `nths xs A` because the counting set is a proper subset of the kept indices.
   The trailing `simp` closes the Nil case. *)
lemma nths_nth:
assumes "n∈A" "n<length xs"
shows "nths xs A ! (card {i. i<n ∧ i∈A}) = xs ! n"
using assms proof (induction xs rule:rev_induct)
  case (snoc x xs)
  then show ?case
  proof (cases "n = length xs")
    case True
    then show ?thesis unfolding nths_append[of xs "[x]" A] nth_append
      using length_nths[of xs A] nths_singleton snoc.prems(1) by auto
  next
    case False
    then have "n < length xs" using snoc by auto
    then have 0:"nths xs A ! card {i. i < n ∧ i ∈ A} = xs ! n" using snoc by auto

    have "{i. i < n ∧ i ∈ A} ⊂ {i. i < length xs ∧ i ∈ A}" using ‹n < length xs› snoc by force
    then have "card {i. i < n ∧ i ∈ A} < length (nths xs A)" unfolding length_nths
      by (simp add: psubset_card_mono)
    then show ?thesis unfolding nths_append[of xs "[x]" A] nth_append using 0
      by (simp add: ‹n < length xs›)
  qed
qed simp

(* A pointwise relation between two lists can be established separately on the
   entries at indices in A and on the entries at indices in -A.
   Proof in two parts:
   1. length xs = length ys.  By contradiction: if one list is strictly shorter,
      then the index `length` of the shorter list lies either in A or in -A,
      which makes the corresponding selected sublist of the longer list strictly
      longer, contradicting the length equality implied by list_all2.
   2. P holds at every index n.  Split on n ∈ A; nths_nth locates entry n of
      both lists at the same position of the selected sublists, where
      list_all2_nthD applies. *)
lemma list_all2_nths:
assumes "list_all2 P (nths xs A) (nths ys A)"
and     "list_all2 P (nths xs (-A)) (nths ys (-A))"
shows "list_all2 P xs ys"
proof -
  have "length xs = length ys"
  proof (rule ccontr; cases "length xs < length ys")
    case True
    then show False
    proof (cases "length xs ∈ A")
      case False
      have "{i. i < length xs ∧ i ∈ - A} ⊂ {i. i < length ys ∧ i ∈ - A}"
        using False ‹length xs < length ys› by force
      then have "length (nths ys (-A)) > length (nths xs (-A))"
        unfolding length_nths by (simp add: psubset_card_mono)
      then show False using assms(2) list_all2_lengthD not_less_iff_gr_or_eq by blast
    next
      case True
      have "{i. i < length xs ∧ i ∈ A} ⊂ {i. i < length ys ∧ i ∈ A}"
        using True ‹length xs < length ys› by force
      then have "length (nths ys A) > length (nths xs A)"
        unfolding length_nths by (simp add: psubset_card_mono)
      then show False using assms(1) list_all2_lengthD not_less_iff_gr_or_eq by blast
    qed
  next
    assume "length xs ≠ length ys"
    case False
    then have "length xs > length ys" using ‹length xs ≠ length ys› by auto
    then show False
    proof (cases "length ys ∈ A")
      case False
      have "{i. i < length ys ∧ i ∈ -A} ⊂ {i. i < length xs ∧ i ∈ -A}"
        using False ‹length xs > length ys›  by force
      then have "length (nths xs (-A)) > length (nths ys (-A))"
        unfolding length_nths by (simp add: psubset_card_mono)
      then show False using assms(2) list_all2_lengthD dual_order.strict_implies_not_eq by blast
    next
      case True
      have "{i. i < length ys ∧ i ∈ A} ⊂ {i. i < length xs ∧ i ∈ A}"
        using True ‹length xs > length ys› by force
      then have "length (nths xs A) > length (nths ys A)"
        unfolding length_nths by (simp add: psubset_card_mono)
      then show False using assms(1) list_all2_lengthD dual_order.strict_implies_not_eq by blast
    qed
  qed

  have "⋀n. n < length xs ⟹ P (xs ! n) (ys ! n)"
  proof -
    fix n assume "n < length xs"
    then have "n < length ys" using ‹length xs = length ys› by auto
    then show "P (xs ! n) (ys ! n)"
    proof (cases "n∈A")
      case True
      have "{i. i < n ∧ i ∈ A} ⊂ {i. i < length xs ∧ i ∈ A}" using ‹n < length xs› ‹n∈A› by force
      then have "card {i. i < n ∧ i ∈ A} < length (nths xs A)" unfolding length_nths
        by (simp add: psubset_card_mono)
      show ?thesis using nths_nth[OF ‹n∈A› ‹n < length xs›] nths_nth[OF ‹n∈A› ‹n < length ys›]
        list_all2_nthD[OF assms(1), of "card {i. i < n ∧ i ∈ A}"] length_nths
        by (simp add: ‹card {i. i < n ∧ i ∈ A} < length (nths xs A)›)
    next
      case False then have "n∈-A" by auto
      have "{i. i < n ∧ i ∈ -A} ⊂ {i. i < length xs ∧ i ∈ -A}" using ‹n < length xs› ‹n∈-A› by force
      then have "card {i. i < n ∧ i ∈ -A} < length (nths xs (-A))" unfolding length_nths
        by (simp add: psubset_card_mono)
      show ?thesis using nths_nth[OF ‹n∈-A› ‹n < length xs›] nths_nth[OF ‹n∈-A› ‹n < length ys›]
        list_all2_nthD[OF assms(2), of "card {i. i < n ∧ i ∈ -A}"] length_nths
        using ‹card {i. i < n ∧ i ∈ - A} < length (nths xs (- A))› by auto
    qed
  qed
  then show ?thesis using ‹length xs = length ys› list_all2_all_nthI by blast
qed

(* nths inverts weave: if xs has exactly as many entries as there are members of
   A below length xs + length ys, and ys as many as there are members of -A
   below that bound, then selecting the A-indices of weave A xs ys returns xs
   and selecting the (-A)-indices returns ys.
   Proof: induction on the total length l + 1 = length xs + length ys, for
   arbitrary xs and ys.  The last position l lies either in A or not:
   - l ∈ A: xs must be non-empty, xs = xs' @ [x].  The cardinality assumptions
     carry over to xs' and ys, weave_append1 peels off x, and the induction
     hypothesis handles the rest.
   - l ∉ A: symmetric, with ys = ys' @ [y] and weave_append2. *)
lemma nths_weave:
assumes "length xs = card {a∈A. a < length xs + length ys}"
assumes "length ys = card {a∈(-A). a < length xs + length ys}"
shows "nths (weave A xs ys) A = xs ∧ nths (weave A xs ys) (-A) = ys"
using assms proof (induction "length xs + length ys" arbitrary: xs ys)
  case 0
  then show ?case
    unfolding weave_def nths_map by simp
next
  case (Suc l)
  then show ?case
  proof (cases "l∈A")
    case True
    then have "l∈{a ∈ A. a < length xs + length ys}" using Suc.hyps mem_Collect_eq zero_less_Suc by auto
    then have "length xs > 0" using Suc by fastforce
    then obtain xs' x where "xs = xs' @ [x]" by (metis append_butlast_last_id length_greater_0_conv)
    then have "l = length xs' + length ys" using Suc.hyps by simp
    have length_xs':"length xs' = card {a ∈ A. a < length xs' + length ys}"
    proof -
      have "{a ∈ A. a < length xs + length ys} = {a ∈ A. a < length xs' + length ys} ∪ {l}"
        using ‹xs = xs' @ [x]› ‹l∈{a ∈ A. a < length xs + length ys}› ‹l = length xs' + length ys›
        by force
      then have "card {a ∈ A. a < length xs + length ys} = card {a ∈ A. a < length xs' + length ys} + 1"
        using ‹l = length xs' + length ys› by fastforce
      then show ?thesis by (metis One_nat_def Suc.prems(1) ‹xs = xs' @ [x]› add_right_imp_eq
        length_Cons length_append list.size(3))
    qed
    have length_ys:"length ys = card {a ∈ - A. a < length xs' + length ys}"
    proof -
      have "l∉{a ∈ - A. a < length xs + length ys}" using ‹l∈A› ‹l = length xs' + length ys› by blast
      have "{a ∈ -A. a < length xs + length ys} = {a ∈ -A. a < length xs' + length ys}"
        apply (rule subset_antisym)
        using  ‹l = length xs' + length ys› ‹Suc l = length xs + length ys› ‹l∉{a ∈ - A. a < length xs + length ys}›
        apply (metis (no_types, lifting) Collect_mono less_Suc_eq mem_Collect_eq)
        using Collect_mono Suc.hyps(2) ‹l = length xs' + length ys› by auto
      then show ?thesis using Suc.prems(2) by auto
    qed
    have "length xs' + length ys ∈ A" using ‹l∈A› ‹l = length xs' + length ys› by blast

    then have "nths (weave A xs ys) A = nths (weave A xs' ys @ [x]) A" unfolding
       ‹xs = xs' @ [x]› using weave_append1[OF ‹length xs' + length ys ∈ A› length_xs'] by metis
    also have "... = nths (weave A xs' ys) A @ nths [x] {a. a + (length xs' + length ys) ∈ A}"
      using nths_append length_weave by metis
    also have "... = nths (weave A xs' ys) A @ [x]"
      using nths_singleton ‹length xs' + length ys ∈ A› by auto
    also have "... = xs" using Suc.hyps(1)[OF ‹l = length xs' + length ys› length_xs' length_ys]
     ‹xs = xs' @ [x]› by presburger
    finally have "nths (weave A xs ys) A = xs" by metis

    have "nths (weave A xs ys) (-A) = nths (weave A xs' ys @ [x]) (-A)" unfolding
       ‹xs = xs' @ [x]› using weave_append1[OF ‹length xs' + length ys ∈ A› length_xs'] by metis
    also have "... = nths (weave A xs' ys) (-A) @ nths [x] {a. a + (length xs' + length ys) ∈ (-A)}"
      using nths_append length_weave by metis
    also have "... = nths (weave A xs' ys) (-A)"
      using nths_singleton ‹length xs' + length ys ∈ A› by auto
    also have "... = ys"
      using Suc.hyps(1)[OF ‹l = length xs' + length ys› length_xs' length_ys] by presburger
    finally show ?thesis using ‹nths (weave A xs ys) A = xs› by auto
  next
    case False
    then have "l∉{a ∈ A. a < length xs + length ys}" using Suc.hyps mem_Collect_eq zero_less_Suc by auto
    then have "length ys > 0" using Suc by fastforce
    then obtain ys' y where "ys = ys' @ [y]" by (metis append_butlast_last_id length_greater_0_conv)
    then have "l = length xs + length ys'" using Suc.hyps by simp
    have length_ys':"length ys' = card {a ∈ -A. a < length xs + length ys'}"
    proof -
      have "{a ∈ -A. a < length xs + length ys} = {a ∈ -A. a < length xs + length ys'} ∪ {l}"
        using ‹ys = ys' @ [y]› ‹l∉{a ∈ A. a < length xs + length ys}› ‹l = length xs + length ys'›
        by force
      then have "card {a ∈ -A. a < length xs + length ys} = card {a ∈ -A. a < length xs + length ys'} + 1"
        using ‹l = length xs + length ys'› by fastforce
      then show ?thesis by (metis One_nat_def Suc.prems(2) ‹ys = ys' @ [y]› add_right_imp_eq
        length_Cons length_append list.size(3))
    qed
    have length_xs:"length xs = card {a ∈ A. a < length xs + length ys'}"
    proof -
      have "l∉{a ∈ A. a < length xs + length ys}" using ‹l∉A› ‹l = length xs + length ys'› by blast
      have "{a ∈ A. a < length xs + length ys} = {a ∈ A. a < length xs + length ys'}"
        apply (rule subset_antisym)
        using  ‹l = length xs + length ys'› ‹Suc l = length xs + length ys› ‹l∉{a ∈ A. a < length xs + length ys}›
        apply (metis (no_types, lifting) Collect_mono less_Suc_eq mem_Collect_eq)
        using Collect_mono Suc.hyps(2) ‹l = length xs + length ys'› by auto
      then show ?thesis using Suc.prems(1) by auto
    qed
    have "length xs + length ys' ∉ A" using ‹l∉A› ‹l = length xs + length ys'› by blast

    then have "nths (weave A xs ys) A = nths (weave A xs ys' @ [y]) A" unfolding
       ‹ys = ys' @ [y]› using weave_append2[OF ‹length xs + length ys' ∉ A› length_ys'] by metis
    also have "... = nths (weave A xs ys') A @ nths [y] {a. a + (length xs + length ys') ∈ A}"
      using nths_append length_weave by metis
    also have "... = nths (weave A xs ys') A"
      using nths_singleton ‹length xs + length ys' ∉ A› by auto
    also have "... = xs"
      using Suc.hyps(1)[OF ‹l = length xs + length ys'› length_xs length_ys'] by auto
    finally have "nths (weave A xs ys) A = xs" by auto

    have "nths (weave A xs ys) (-A) = nths (weave A xs ys' @ [y]) (-A)" unfolding
       ‹ys = ys' @ [y]› using weave_append2[OF ‹length xs + length ys' ∉ A› length_ys'] by metis
    also have "... = nths (weave A xs ys') (-A) @ nths [y] {a. a + (length xs + length ys') ∈ (-A)}"
      using nths_append length_weave by metis
    also have "... = nths (weave A xs ys') (-A) @ [y]"
      using nths_singleton ‹length xs + length ys' ∉ A› by auto
    also have "... = ys"
      using Suc.hyps(1)[OF ‹l = length xs + length ys'› length_xs length_ys'] ‹ys = ys' @ [y]› by simp
    finally show ?thesis using ‹nths (weave A xs ys) A = xs› by auto
  qed
qed

(* Under the same cardinality assumptions as nths_weave, the woven list contains
   exactly the elements of xs and ys.
   Proof:
   - "⊆": an entry at index i comes, by nth_weave, from xs (if i ∈ A) or from ys
     (if i ∉ A); the source index is in range because the set of smaller
     members of A (resp. -A) is a proper subset of the members below
     length xs + length ys.
   - "⊇": by nths_weave both xs and ys are sublists (nths) of the woven list. *)
lemma set_weave:
assumes "length xs = card {a∈A. a < length xs + length ys}"
assumes "length ys = card {a∈-A. a < length xs + length ys}"
shows "set (weave A xs ys) = set xs ∪ set ys"
proof
  show "set (weave A xs ys) ⊆ set xs ∪ set ys"
  proof
    fix x assume "x∈set (weave A xs ys)"
    then obtain i where "weave A xs ys ! i = x" "i<length (weave A xs ys)" by (meson in_set_conv_nth)
    show "x ∈ set xs ∪ set ys"
    proof (cases "i∈A")
      case True
      then have "i ∈ {a∈A. a < length xs + length ys}" unfolding length_weave
        by (metis ‹i < length (weave A xs ys)› length_weave mem_Collect_eq)
      then have "{a ∈ A. a < i} ⊂ {a∈A. a < length xs + length ys}"
        using Collect_mono ‹i < length (weave A xs ys)›[unfolded length_weave] le_Suc_ex  less_imp_le_nat trans_less_add1
        le_neq_trans less_irrefl mem_Collect_eq by auto
      then have "card {a ∈ A. a < i} < card {a∈A. a < length xs + length ys}" by (simp add: psubset_card_mono)
      then show "x ∈ set xs ∪ set ys"
        unfolding nth_weave[OF ‹i<length (weave A xs ys)›, unfolded ‹weave A xs ys ! i = x›] using True
        using UnI1 assms(1) nth_mem by auto
    next
      case False
      have "i∉A ⟹ i ∈ {a∈-A. a < length xs + length ys}" unfolding length_weave
        by (metis ComplI ‹i < length (weave A xs ys)› length_weave mem_Collect_eq)
      then have "{a ∈ -A. a < i} ⊂ {a∈-A. a < length xs + length ys}"
        using Collect_mono ‹i < length (weave A xs ys)›[unfolded length_weave] le_Suc_ex  less_imp_le_nat trans_less_add1
        le_neq_trans less_irrefl mem_Collect_eq using False by auto
      then have "card {a ∈ -A. a < i} < card {a∈-A. a < length xs + length ys}" by (simp add: psubset_card_mono)
      then show "x ∈ set xs ∪ set ys"
        unfolding nth_weave[OF ‹i<length (weave A xs ys)›, unfolded ‹weave A xs ys ! i = x›] using False
        using UnI1 assms(2) nth_mem by auto
    qed
  qed
  show "set xs ∪ set ys ⊆ set (weave A xs ys)"
    using nths_weave[OF assms] by (metis Un_subset_iff set_nths_subset)
qed


(* weave inverts nths: weaving the A-indexed entries of xs with the (-A)-indexed
   entries of xs reconstructs xs.  Declared as a simp rule.
   Proof: reverse induction on xs.  For xs @ [x] we split on whether the index
   of x (length xs) lies in A.  Facts 0 and 1 are the two premises of
   weave_append1 (resp. weave_append2); facts 2 and 3 describe how nths
   distributes over the appended element. *)
lemma weave_complementary_nthss[simp]:
 "weave A (nths xs A) (nths xs (-A)) = xs"
proof (induction xs rule:rev_induct)
  case Nil
  then show ?case by (metis gen_length_def length_0_conv length_code length_weave nths_nil)
next
  case (snoc x xs)
  have length_xs:"length xs = length (nths xs A) + length (nths xs (-A))" by (metis length_weave snoc.IH)
  show ?case
  proof (cases "(length xs)∈A")
    case True
    have 0:"length (nths xs A) + length (nths xs (-A)) ∈ A" using length_xs True by metis
    have 1:"length (nths xs A) = card {a ∈ A. a < length (nths xs A) + length (nths xs (-A))}"
      using length_nths[of xs A] by (metis (no_types, lifting) Collect_cong length_xs)
    have 2:"nths (xs @ [x]) A = nths xs A @ [x]"
      unfolding nths_append[of xs "[x]" A] using nths_singleton True by auto
    have 3:"nths (xs @ [x]) (-A) = nths xs (-A)"
      unfolding nths_append[of xs "[x]" "-A"] using True by auto
    show ?thesis unfolding 2 3 weave_append1[OF 0 1] snoc.IH by metis
  next
    case False
    have 0:"length (nths xs A) + length (nths xs (-A)) ∉ A" using length_xs False by metis
    have 1:"length (nths xs (-A)) = card {a ∈ -A. a < length (nths xs A) + length (nths xs (-A))}"
      using length_nths[of xs "-A"] by (metis (no_types, lifting) Collect_cong length_xs)
    have 2:"nths (xs @ [x]) A = nths xs A"
      unfolding nths_append[of xs "[x]" A] using nths_singleton False by auto
    have 3:"nths (xs @ [x]) (-A) = nths xs (-A) @ [x]"
      unfolding nths_append[of xs "[x]" "-A"] using False by auto
    show ?thesis unfolding 2 3 weave_append2[OF 0 1] snoc.IH by metis
  qed
qed

(* Variant of the library fact length_nths with the set comprehension written in
   the bounded form {i∈I. i < length xs} instead of
   {i. i < length xs ∧ i ∈ I}. *)
lemma length_nths': "length (nths xs I) = card {i∈I. i < length xs}"
unfolding length_nths by meson

end

Theory DL_Submatrix

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)
section ‹Submatrices›

theory DL_Submatrix
imports Matrix DL_Missing_Sublist
begin


section "Submatrix"

(* submatrix A I J keeps the rows of A whose index lies in I and the columns
   whose index lies in J.  The result has as many rows (columns) as there are
   valid row (column) indices of A in I (J); its entry (i, j) is the entry of A
   at (pick I i, pick J j). *)
definition submatrix :: "'a mat ⇒ nat set ⇒ nat set ⇒ 'a mat" where
"submatrix A I J = mat (card {i. i<dim_row A ∧ i∈I}) (card {j. j<dim_col A ∧ j∈J}) (λ(i,j). A $$ (pick I i, pick J j))"

(* Dimensions of a submatrix: the number of rows (columns) equals the number of
   valid row (column) indices of A that belong to I (J).  Both facts follow
   directly from the definition. *)
lemma dim_submatrix: "dim_row (submatrix A I J) = card {i. i<dim_row A ∧ i∈I}"
                     "dim_col (submatrix A I J) = card {j. j<dim_col A ∧ j∈J}"
  unfolding submatrix_def by simp_all

(* Entry lookup in a submatrix: for indices i, j within the submatrix
   dimensions, entry (i, j) of submatrix A I J is entry
   (pick I i, pick J j) of A. *)
lemma submatrix_index:
assumes "i<card {i. i<dim_row A ∧ i∈I}"
assumes "j<card {j. j<dim_col A ∧ j∈J}"
shows "submatrix A I J $$ (i,j) = A $$ (pick I i, pick J j)"
  unfolding submatrix_def by (simp add: assms(1) assms(2))

(* Two spellings of the same set: a plain comprehension with a conjunction
   versus a bounded comprehension over I.  Used to switch between the two
   forms by rewriting. *)
lemma set_le_in:"{a. a < n ∧ a ∈ I} = {a ∈ I. a < n}" by meson

(* Converse lookup: an entry (i, j) of A with i ∈ I and j ∈ J is found in
   submatrix A I J at the position given by the number of members of I below i
   and the number of members of J below j.
   Proof: pick_card_in_set shows that pick maps these counts back to i and j;
   the counts are valid submatrix indices because the sets of smaller members
   are proper subsets of the kept index sets. *)
lemma submatrix_index_card:
assumes "i<dim_row A" "j<dim_col A" "i∈I" "j∈J"
shows "submatrix A I J $$ (card {a∈I. a < i}, card {a∈J. a < j}) = A $$ (i, j)"
proof -
  have "i = pick I (card {a∈I. a < i})"
       "j = pick J (card {a∈J. a < j})" using pick_card_in_set assms by auto
  have "{a∈I. a < i} ⊂ {i. i < dim_row A ∧ i ∈ I}"
       "{a∈J. a < j} ⊂ {j. j < dim_col A ∧ j ∈ J}"
    unfolding set_le_in using ‹i<dim_row A› ‹j<dim_col A› Collect_mono less_imp_le less_le_trans ‹i∈I› ‹j∈J› by auto
  then have "card {a∈I. a < i} < card {i. i < dim_row A ∧ i ∈ I}"
            "card {a∈J. a < j} < card {j. j < dim_col A ∧ j ∈ J}" by (simp_all add: psubset_card_mono)
  then show ?thesis
    using ‹i = pick I (card {a ∈ I. a < i})› ‹j = pick J (card {a ∈ J. a < j})› submatrix_index by fastforce
qed

(* Selecting rows I and columns J in one step equals first selecting the
   columns J (keeping all rows) and then selecting the rows I (keeping all
   columns).
   Proof: matrix extensionality (eq_matI).  The dimensions agree by
   dim_submatrix; for the entries, submatrix_index is applied three times
   (outer matrix, and both layers of the nested one), with pick_UNIV removing
   the trivial selections and pick_le bounding the picked row index. *)
lemma submatrix_split: "submatrix A I J = submatrix (submatrix A UNIV J) I UNIV"
proof (rule eq_matI)
  show "dim_row (submatrix A I J) = dim_row (submatrix (submatrix A UNIV J) I UNIV)"
    by (simp add: dim_submatrix(1))
  show "dim_col (submatrix A I J) = dim_col (submatrix (submatrix A UNIV J) I UNIV)"
    by (simp add: dim_submatrix(2))
  fix i j assume ij_le:"i < dim_row (submatrix (submatrix A UNIV J) I UNIV)" "j < dim_col (submatrix (submatrix A UNIV J) I UNIV)"
  then have ij_le1:"i<card {i. i < dim_row A ∧ i ∈ I}" "j<card {i. i < dim_col A ∧ i ∈ J}"
    by (simp_all add: dim_submatrix)
  then have ij_le2:"i<card {i. i < dim_row (submatrix A UNIV J) ∧ i ∈ I}" "j<card {i. i < dim_col (submatrix A UNIV J) ∧ i ∈ UNIV}"
    by (simp_all add: dim_submatrix)
  then have i_le3:"pick I i<card {i. i < dim_row A ∧ i ∈ UNIV}"
    using ij_le1(1) pick_le by auto
  have j_le3: "pick UNIV j<card {i. i < dim_col A ∧ i ∈ J}" unfolding pick_UNIV by (simp add: ij_le1(2))
  then show "submatrix A I J $$ (i, j) = submatrix (submatrix A UNIV J) I UNIV $$ (i, j)"
    unfolding submatrix_index[OF ij_le1] submatrix_index[OF ij_le2] submatrix_index[OF i_le3 j_le3]
    unfolding pick_UNIV by metis
qed

end

Theory DL_Rank_Submatrix

(* Author: Alexander Bentkamp, Universität des Saarlandes
*)

section ‹Rank and Submatrices›

theory DL_Rank_Submatrix
imports DL_Rank DL_Submatrix Matrix
begin

(* Row i of the submatrix that keeps all columns and only the rows in I is row
   (pick I i) of the original matrix.
   Proof: vector extensionality (eq_vecI).  Both rows have dimension dim_col A;
   entrywise, unfold row_def and use submatrix_index, with pick_UNIV removing
   the trivial column selection. *)
lemma row_submatrix_UNIV:
assumes "i < card {i. i < dim_row A ∧ i ∈ I}"
shows "row (submatrix A I UNIV) i = row A (pick I i)"
proof (rule eq_vecI)
  show dim_eq:"dim_vec (row (submatrix A I UNIV) i) = dim_vec (row A (pick I i))"
    unfolding carrier_vecD[OF row_carrier] dim_submatrix by auto
  fix j assume "j < dim_vec (row A (pick I i))"
  then have "j < dim_col (submatrix A I UNIV)" "j < dim_col A" "j < card {j. j < dim_col A ∧ j ∈ UNIV}" using dim_eq by auto
  show "row (submatrix A I UNIV) i $ j = row A (pick I i) $ j"
    unfolding row_def index_vec[OF ‹j < dim_col (submatrix A I UNIV)›] index_vec[OF ‹j < dim_col A›]
    using submatrix_index[OF assms ‹j < card {j. j < dim_col A ∧ j ∈ UNIV}›] using pick_UNIV by auto
qed

(* If the columns of the row-restricted submatrix "submatrix A I UNIV" are
   pairwise distinct, then so are the columns of A itself.
   The proof goes by contraposition (rule contrapos_pp): two equal columns of A
   at different positions yield two equal columns of the submatrix at the same
   positions. *)
lemma distinct_cols_submatrix_UNIV:
assumes "distinct (cols (submatrix A I UNIV))"
shows "distinct (cols A)"
using assms proof (rule contrapos_pp)
  assume "¬ distinct (cols A)"
  (* Extract two positions i and j holding the same column of A.
     NOTE(review): the last obtained fact is printed as "ij" in this copy; the
     relation symbol between i and j seems to be missing (it is used as
     "i  j" at the end of the proof) -- confirm against upstream. *)
  then obtain i j where "i < dim_col A" "j < dim_col A" "(cols A)!i = (cols A)!j" "ij"
    using distinct_conv_nth cols_length by metis
  (* The same positions are valid column indices of the submatrix; the step
     only unfolds dim_submatrix and simplifies. *)
  have "i < dim_col (submatrix A I UNIV)" "j < dim_col (submatrix A I UNIV)"
    unfolding dim_submatrix using i < dim_col A j < dim_col Aby simp_all
  then have "i < length (cols (submatrix A I UNIV))" "j < length (cols (submatrix A I UNIV))"
    unfolding cols_length by simp_all
  (* Main step: columns i and j of the submatrix are equal as vectors. *)
  have "(cols (submatrix A I UNIV))!i = (cols (submatrix A I UNIV))!j"
  proof (rule eq_vecI)
    show "dim_vec (cols (submatrix A I UNIV) ! i) = dim_vec (cols (submatrix A I UNIV) ! j)"
      by (simp add: i < dim_col (submatrix A I UNIV) j < dim_col (submatrix A I UNIV))
    fix k assume "k < dim_vec (cols (submatrix A I UNIV) ! j)"
    then have "k < dim_row (submatrix A I UNIV)"
      using j < length (cols (submatrix A I UNIV))  by auto
    then have  "k < card {j. j < dim_row A  j  I}"  using dim_submatrix(1) by metis
    (* Entry k of submatrix column i is entry "pick I k" of column i of A. *)
    have i_transfer:"cols (submatrix A I UNIV) ! i $ k = (cols A) ! i $ (pick I k)"
      unfolding cols_nth[OF i < dim_col (submatrix A I UNIV)] col_def index_vec[OF k < dim_row (submatrix A I UNIV)]
      unfolding submatrix_index[OF k < card {j. j < dim_row A  j  I} i < dim_col (submatrix A I UNIV)[unfolded dim_submatrix]]
      unfolding pick_UNIV cols_nth[OF i < dim_col A] col_def index_vec[OF pick_le[OF k < card {j. j < dim_row A  j  I}]]
      by metis
    (* The same transfer statement for column j; the proof is identical up to
       replacing i by j. *)
    have j_transfer:"cols (submatrix A I UNIV) ! j $ k = (cols A) ! j $ (pick I k)"
      unfolding cols_nth[OF j < dim_col (submatrix A I UNIV)] col_def index_vec[OF k < dim_row (submatrix A I UNIV)]
      unfolding submatrix_index[OF k < card {j. j < dim_row A  j  I} j < dim_col (submatrix A I UNIV)[unfolded dim_submatrix]]
      unfolding pick_UNIV cols_nth[OF j < dim_col A] col_def index_vec[OF pick_le[OF k < card {j. j < dim_row A  j  I}]]
      by metis
    (* Both entries reduce to entries of the equal columns of A. *)
    show "cols (submatrix A I UNIV) ! i $ k = cols (submatrix A I UNIV) ! j $ k"
      using ‹cols A ! i = cols A ! j i_transfer j_transfer by auto
  qed
  (* Two positions with equal entries contradict distinctness, by
     distinct_conv_nth. *)
  then show "¬ distinct (cols (submatrix A I UNIV))" unfolding distinct_conv_nth
    using i < length (cols (submatrix A I UNIV)) j < length (cols (submatrix A I UNIV)) i  j by blast
qed

(* Relates the set of columns of the column-restricted submatrix
   "submatrix A UNIV J" (all rows kept, columns selected by J) to the set of
   columns of A.
   NOTE(review): the relation symbol between the two sets is not visible in
   this copy; the lemma name and the proof shape (fix an element c of the left
   set, show that c lies in the right set) indicate set inclusion -- confirm
   against upstream. *)
lemma cols_submatrix_subset: "set (cols (submatrix A UNIV J))  set (cols A)"
proof
  fix c assume "c  set (cols (submatrix A UNIV J))"
  (* c occurs at some list position j of the submatrix column list. *)
  then obtain j where "j < length (cols (submatrix A UNIV J))" "cols (submatrix A UNIV J) ! j = c"
    by (meson in_set_conv_nth)
  then have "j < dim_col (submatrix A UNIV J)" by simp
  then have "j < card {j. j < dim_col A  j  J}" by (simp add: dim_submatrix(2))
  (* Key fact: the j-th column of the submatrix is column "pick J j" of A.
     After rewriting list lookups to col via cols_nth, this is proved as a
     vector equality with eq_vecI. *)
  have "cols (submatrix A UNIV J) ! j = cols A ! (pick J j)"
    unfolding cols_nth[OF j < dim_col (submatrix A UNIV J)] cols_nth[OF pick_le[OF j < card {j. j < dim_col A  j  J}]]
  proof (rule eq_vecI)
    show "dim_vec (col (submatrix A UNIV J) j) = dim_vec (col A (pick J j))" unfolding dim_col dim_submatrix by auto
    fix i assume "i < dim_vec (col A (pick J j))"
    then have "i < dim_row A" by simp
    then have "i < dim_row (submatrix A UNIV J)" using ‹dim_vec (col (submatrix A UNIV J) j) = dim_vec (col A (pick J j)) by auto
    (* Entry-wise equality, delegated to metis with submatrix_index and the
       dimension facts collected above. *)
    show "col (submatrix A UNIV J) j $ i = col A (pick J j) $ i"
      unfolding col_def index_vec[OF i < dim_row (submatrix A UNIV J)] index_vec[OF i < dim_row A]
      using submatrix_index by (metis (no_types, lifting) ‹dim_vec (col (submatrix A UNIV J) j) = dim_vec (col A (pick J j))
      i < dim_vec (col A (pick J j)) j < dim_col (submatrix A UNIV J) dim_col dim_submatrix(1) dim_submatrix(2) pick_UNIV)
  qed
  (* Hence c equals an element of cols A at the valid index "pick J j"
     (pick_le supplies the bound, nth_mem the membership). *)
  then show "c  set (cols A)"
    using ‹cols (submatrix A UNIV J) ! j = c
    using pick_le[OF j < card {j. j < dim_col A  j  J}] by (metis cols_length nth_mem)
qed

(* Linear dependence of the columns survives restricting the rows.
   Inside locale vec_space: given a matrix A in carrier_mat n nc whose column
   set is linearly dependent, and given that the columns of
   "submatrix A I UNIV" are distinct, the column set of that submatrix is
   linearly dependent in the vector module whose dimension is the cardinality
   of the selected row-index set.
   Proof idea: a nonzero vector v annihilated by A (obtained via lin_depE) is
   also annihilated by the submatrix, since every row of the submatrix is a row
   of A (row_submatrix_UNIV); lin_depI then gives the claim.
   NOTE(review): several symbols (membership, inequality, subscripts on the
   zero vector and on the products) are not visible in this copy -- confirm
   the exact statement against upstream. *)
lemma (in vec_space) lin_dep_submatrix_UNIV:
assumes "A  carrier_mat n nc"
assumes "lin_dep (set (cols A))"
assumes "distinct (cols (submatrix A I UNIV))"
shows "LinearCombinations.module.lin_dep class_ring (module_vec TYPE('a) (card {i. i < n  i  I})) (set (cols (submatrix A I UNIV)))"
  (is "LinearCombinations.module.lin_dep class_ring ?M (set ?S')")
proof -
  (* Witness vector v for the dependence of the columns of A; distinctness of
     cols A, needed by lin_depE, comes from distinct_cols_submatrix_UNIV. *)
  obtain v where 2:"v  carrier_vec nc" and 3:"v  0v nc" and "A *v v = 0v n"
    using vec_space.lin_depE[OF assms(1) assms(2) distinct_cols_submatrix_UNIV[OF assms(3)]] by auto
  (* Fact 1: carrier (dimensions) of the row-restricted submatrix. *)
  have 1: "submatrix A I UNIV  carrier_mat (card {i. i < n  i  I}) nc"
    apply (rule carrier_matI) unfolding dim_submatrix using A  carrier_mat n nc by auto
  (* Fact 4: the submatrix applied to v is the zero vector, shown per entry. *)
  have 4:"submatrix A I UNIV *v v = 0v (card {i. i < n  i  I})"
  proof (rule eq_vecI)
    show dim_eq:"dim_vec (submatrix A I UNIV *v v) = dim_vec (0v (card {i. i < n  i  I}))" using "1" by auto
    fix i assume "i < dim_vec (0v (card {i. i < n  i  I}))"
    then have i_le:"i < card {i. i < n  i  I}" by auto
    (* Calculation: entry i of the product is a row-times-v scalar product;
       that row is row "pick I i" of A; the corresponding entry of the product
       of A and v is zero. *)
    have "(submatrix A I UNIV *v v) $ i = row (submatrix A I UNIV) i  v" using dim_eq i_le by auto
    also have "... = row A (pick I i)  v" using row_submatrix_UNIV
      by (metis (no_types, lifting)  dim_eq dim_mult_mat_vec dim_submatrix(1) i < dim_vec (0v (card {i. i < n  i  I})))
    also have "... = 0"
      using A *v v = 0v n i_le[THEN pick_le] by (metis assms(1) index_mult_mat_vec carrier_matD(1) index_zero_vec(1))
    also have "... = 0v (card {i. i < n  i  I}) $ i" by (simp add: i_le)
    finally show "(submatrix A I UNIV *v v) $ i = 0v (card {i. i < n  i  I}) $ i" by metis
  qed
  (* Combine facts 1-4 through lin_depI; assms(3) supplies distinctness. *)
  show ?thesis using vec_space.lin_depI[OF 1 2 3 4] using assms(3) by auto
qed

(* Relates the rank of A to a minor of A: from an assumption about the
   determinant of "submatrix A I J", conclude a comparison between the number
   of column indices below nc selected by J and "rank A".
   NOTE(review): the relation symbols in the second assumption and in the
   conclusion are not visible in this copy. The final step uses
   rank_ge_card_indpt, so the conclusion is presumably "card ... <= rank A",
   and the assumption presumably states that the determinant is nonzero --
   confirm against upstream.
   Proof outline:
     1. the minor is square and has full rank;
     2. hence its columns are distinct and linearly independent;
     3. submatrix_split rewrites the minor as a row restriction of
        "submatrix A UNIV J";
     4. lin_dep_submatrix_UNIV, used with blast, gives independence of the
        columns of "submatrix A UNIV J";
     5. those columns are columns of A (cols_submatrix_subset), so
        rank_ge_card_indpt bounds the rank. *)
lemma (in vec_space) rank_gt_minor:
assumes "A  carrier_mat n nc"
assumes "det (submatrix A I J)  0"
shows "card {j. j < nc  j  J}  rank A"
proof -
  (* Step 1: squareness is derived from det_def and the determinant
     assumption; full rank follows via low_rank_det_zero. *)
  have square:"dim_row (submatrix A I J) = dim_col (submatrix A I J)"
   using det_def ‹det (submatrix A I J)  0 by metis
  then have full_rank:"vec_space.rank (dim_row (submatrix A I J)) (submatrix A I J) = dim_row (submatrix A I J)"
   using vec_space.low_rank_det_zero assms(2) carrier_matI by auto
  (* Step 2: distinct columns (via non_distinct_low_rank), then linear
     independence (via full_rank_lin_indpt). *)
  then have distinct:"distinct (cols (submatrix A I J))" 
    using vec_space.non_distinct_low_rank square less_irrefl carrier_matI by metis
  then have indpt:"LinearCombinations.module.lin_indpt class_ring (module_vec TYPE('a) (dim_row (submatrix A I J))) (set (cols (submatrix A I J)))"
     using vec_space.full_rank_lin_indpt[OF _ full_rank distinct] square by fastforce

  (* Step 3: restate both facts for the two-stage form
     "submatrix (submatrix A UNIV J) I UNIV" using submatrix_split. *)
  have distinct2: "distinct (cols (submatrix (submatrix A UNIV J) I UNIV))" using submatrix_split distinct by metis
  have indpt2:"LinearCombinations.module.lin_indpt class_ring (module_vec TYPE('a) (card {i. i < n  i  I})) (set (cols (submatrix (submatrix A UNIV J) I UNIV)))"
    using submatrix_split dim_submatrix(1) indpt by (metis (full_types) assms(1) carrier_matD(1))

  (* Step 4: carrier of the column-restricted matrix, then independence of its
     columns from indpt2 and lin_dep_submatrix_UNIV. *)
  have "submatrix A UNIV J  carrier_mat n (dim_col (submatrix A UNIV J))"
    apply (rule carrier_matI) unfolding dim_submatrix(1) using A  carrier_mat n nc carrier_matD by simp_all
  have "lin_indpt (set (cols (submatrix A UNIV J)))"
    using indpt2 vec_space.lin_dep_submatrix_UNIV[OF ‹submatrix A UNIV J  carrier_mat n (dim_col (submatrix A UNIV J)) _ distinct2] by blast
  have distinct3:"distinct (cols (submatrix A UNIV J))" by (metis distinct distinct_cols_submatrix_UNIV submatrix_split)
  (* Step 5: instantiate rank_ge_card_indpt with the subset and independence
     facts; the "unfolded" rewrites turn the cardinality of the column set
     into the cardinality of the selected index set (distinct_card,
     cols_length, dim_submatrix) and replace dim_col A by nc. *)
  show ?thesis using
    rank_ge_card_indpt[OF A  carrier_mat n nc cols_submatrix_subset ‹lin_indpt (set (cols (submatrix A UNIV J))),
    unfolded distinct_card[OF distinct3, unfolded cols_length dim_submatrix], unfolded carrier_matD(2)[OF A  carrier_mat n nc]]
    by blast
qed

end